I first did this rederivation when trying to understand EWC in 2019, while experimenting with continual learning at the SDSUG AIML lab with Davit Soselia and others. No guarantees that everything is correct, but it mostly seems to check out.
- Data-generating distribution: The mixture of the underlying distributions of the dataset from the first task and the dataset from the second task.
- $W = w$: The event that our current parameters $w$ are optimal for our model's ability to model the data-generating distribution.
- $x$: The event that some data point or collection of data points $x$ has been generated from the data-generating distribution.
- $D_a$: The event that the data in some array $D_a$ has been generated from the data-generating distribution.
- $D_b$: The event that the data in some array $D_b$ has been generated from the data-generating distribution.
- $D$: The event that the data in some array $D$ (which is the union of the arrays $D_a$ and $D_b$) has been generated from the data-generating distribution. This can be expressed as:

  $$D = D_a \cap D_b, \quad \text{so} \quad P(D) = P(D_a, D_b)$$
- We will use a flat prior for $P(W = w)$. This means that when we have seen no data, we assume that any value of $w$ is just as likely to be the optimal weights for our data-generating distribution as any other.
- We will assume that the non-diagonal elements of the Hessian of the model's negative log-likelihood on the data are negligible. In other words, we assume that the variation in the loss with respect to any given parameter $w_i$ does not significantly depend on any other parameter $w_j$. Mathematically, this means that:

  $$\frac{\partial^2 \left( -\log P(D \mid W = w) \right)}{\partial w_i \, \partial w_j} \approx 0 \quad \text{for } i \neq j$$

  Intuitively, this assumption means that how the loss varies with any given parameter is (approximately) unaffected by the values of the other parameters.
- $D_a$ and $D_b$ are conditionally independent given $W = w$. This can be written as:

  $$P(D_a, D_b \mid W = w) = P(D_a \mid W = w) \, P(D_b \mid W = w)$$

  Intuitively, this assumption makes sense because the contents of $D_a$ and $D_b$ come from two different tasks and do not overlap.
- In supervised learning terms, if we imagine the datasets as key-value pairs (with samples as keys and labels as values), $D_a$ and $D_b$ will not share keys or have similar keys to each other.
- All gradients and Hessians are with respect to $w$.
- $N$ is the number of samples $x$ in $D_a$.
- $M$ is the dimension of $w$.
Let us consider the probability of $W = w$ given $D$, that is, $P(W = w \mid D)$.

According to Bayes' Theorem:

$$P(W = w \mid D) = \frac{P(D \mid W = w) \, P(W = w)}{P(D)}$$

Expanding the definition of $D$:

$$P(W = w \mid D_a, D_b) = \frac{P(D_a, D_b \mid W = w) \, P(W = w)}{P(D_a, D_b)}$$

From our presupposition, we know that $D_a$ and $D_b$ are conditionally independent given $W = w$:

$$P(D_a, D_b \mid W = w) = P(D_a \mid W = w) \, P(D_b \mid W = w)$$

Substituting this into the previous equation:

$$P(W = w \mid D_a, D_b) = \frac{P(D_a \mid W = w) \, P(D_b \mid W = w) \, P(W = w)}{P(D_a, D_b)}$$
Rearranging the terms:

$$P(W = w \mid D_a, D_b) = P(D_b \mid W = w) \cdot \frac{P(D_a \mid W = w) \, P(W = w)}{P(D_a, D_b)}$$

We know that:

$$P(D_a, D_b) = P(D_b \mid D_a) \, P(D_a)$$

Thus, we can rewrite the equation as:

$$P(W = w \mid D_a, D_b) = P(D_b \mid W = w) \cdot \frac{P(D_a \mid W = w) \, P(W = w)}{P(D_b \mid D_a) \, P(D_a)}$$

We also know that, by Bayes' Theorem applied to the first task alone:

$$P(W = w \mid D_a) = \frac{P(D_a \mid W = w) \, P(W = w)}{P(D_a)}$$

Therefore, we get:

$$P(W = w \mid D_a, D_b) = \frac{P(D_b \mid W = w) \, P(W = w \mid D_a)}{P(D_b \mid D_a)}$$

Next, we know that, by the definition of $D$:

$$P(W = w \mid D) = P(W = w \mid D_a, D_b)$$

Thus, we can simplify it to:

$$P(W = w \mid D) = \frac{P(D_b \mid W = w) \, P(W = w \mid D_a)}{P(D_b \mid D_a)}$$
For visibility, let’s take the logarithm of both sides:

$$\log P(W = w \mid D) = \log \frac{P(D_b \mid W = w) \, P(W = w \mid D_a)}{P(D_b \mid D_a)}$$

Rearranging the terms:

$$\log P(W = w \mid D) = \log P(D_b \mid W = w) + \log P(W = w \mid D_a) - \log P(D_b \mid D_a)$$
- $\log P(D_b \mid W = w)$ is the log-likelihood of our model on dataset $D_b$. Thus, the probability of $W$ being optimal on $D$ depends on the log-likelihood of our model on dataset $D_b$.
- $\log P(W = w \mid D_a)$ is an intractable posterior term that represents the probability of $W$ being optimal given that $D_a$ has been generated from the data-generating distribution.
- Finally, we don't care about the $\log P(D_b \mid D_a)$ term because it neither depends on nor is conditioned on the event that our weights are optimal, so it is a constant regardless of what parameters we have.
In other words, to maximize $P(W = w \mid D)$, we need to do two things:

- Maximize $P(D_b \mid W = w)$
- Maximize $P(W = w \mid D_a)$
To achieve the first, we need to train the model with a negative log-likelihood loss on dataset $D_b$.

The second term, $P(W = w \mid D_a)$, is intractable, so the rest of this derivation is devoted to approximating it.
Let’s suppose that we have randomly initialized our model and then trained it with a negative log-likelihood loss on dataset $D_a$ until convergence, arriving at parameters $w_a^*$ that minimize this loss.

Thus, we have:

$$\nabla \left( -\log P(D_a \mid W = w_a^*) \right) = 0$$

Furthermore, we know that the array $D_a$ consists of independently sampled data points $x$, so:

$$P(D_a \mid W = w) = \prod_{x \in D_a} P(x \mid W = w)$$

Taking the logarithm of both sides, we get:

$$\log P(D_a \mid W = w) = \sum_{x \in D_a} \log P(x \mid W = w)$$

The gradient of the sum can be written as:

$$\nabla \left( -\log P(D_a \mid W = w_a^*) \right) = \sum_{x \in D_a} \nabla \left( -\log P(x \mid W = w_a^*) \right) = 0$$

Dividing by $N$, the number of samples in $D_a$, we get:

$$\frac{1}{N} \sum_{x \in D_a} \nabla \left( -\log P(x \mid W = w_a^*) \right) = 0$$

Thus, we have:

$$\mathbb{E}_{x \sim D_a} \left[ \nabla \left( -\log P(x \mid W = w_a^*) \right) \right] = 0$$
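As a quick sanity check of this fact, here is a minimal NumPy sketch with a toy model of my own choosing (a 1-D Gaussian with unknown mean and unit variance, so the per-sample gradient of the negative log-likelihood with respect to $\mu$ is simply $\mu - x$):

```python
import numpy as np

rng = np.random.default_rng(0)
D_a = rng.normal(loc=2.0, scale=1.0, size=1000)  # toy "task A" data

# For -log N(x | mu, 1), the minimizer of the negative log-likelihood is the sample mean.
mu_star = D_a.mean()

# Per-sample gradient of -log N(x | mu, 1) with respect to mu, evaluated at mu_star.
per_sample_grads = mu_star - D_a

print(per_sample_grads.sum())   # ~0 (up to floating-point error)
print(per_sample_grads.mean())  # ~0, i.e. the expectation above
```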
Next, we know that, by Bayes' Theorem:

$$P(W = w \mid D_a) = \frac{P(D_a \mid W = w) \, P(W = w)}{P(D_a)}$$

Thus, the negative log of $P(W = w \mid D_a)$ is:

$$-\log P(W = w \mid D_a) = -\log P(D_a \mid W = w) - \log P(W = w) + \log P(D_a)$$

Taking the gradient of both sides:

$$\nabla \left( -\log P(W = w \mid D_a) \right) = \nabla \left( -\log P(D_a \mid W = w) \right) - \nabla \log P(W = w) + \nabla \log P(D_a)$$

Since $P(D_a)$ does not depend on $w$, the term $\nabla \log P(D_a)$ is zero.

We also know that, by the same argument applied to $D_b$:

$$P(W = w \mid D_b) = \frac{P(D_b \mid W = w) \, P(W = w)}{P(D_b)}$$

Thus, the negative log of $P(W = w \mid D_b)$ is:

$$-\log P(W = w \mid D_b) = -\log P(D_b \mid W = w) - \log P(W = w) + \log P(D_b)$$

Taking the gradient of both sides:

$$\nabla \left( -\log P(W = w \mid D_b) \right) = \nabla \left( -\log P(D_b \mid W = w) \right) - \nabla \log P(W = w) + \nabla \log P(D_b)$$

Since we are using a flat prior for $P(W = w)$, the term $\nabla \log P(W = w)$ is zero as well, and $\nabla \log P(D_b)$ vanishes because $P(D_b)$ does not depend on $w$.

Therefore, we conclude that:

$$\nabla \left( -\log P(W = w \mid D_a) \right) = \nabla \left( -\log P(D_a \mid W = w) \right)$$

As such, we have:

$$\nabla \left( -\log P(W = w \mid D_b) \right) = \nabla \left( -\log P(D_b \mid W = w) \right)$$

Similarly, for the Hessian:

$$\nabla^2 \left( -\log P(W = w \mid D_a) \right) = \nabla^2 \left( -\log P(D_a \mid W = w) \right), \qquad \nabla^2 \left( -\log P(W = w \mid D_b) \right) = \nabla^2 \left( -\log P(D_b \mid W = w) \right)$$

From Lemma 3 and Lemma 1, we get:

$$\nabla \left( -\log P(W = w_a^* \mid D_a) \right) = 0$$

From Lemma 4, we get:

$$\nabla^2 \left( -\log P(W = w \mid D_a) \right) = \nabla^2 \left( -\log P(D_a \mid W = w) \right)$$
The Laplace approximation procedure here follows the one outlined in chapter 4.4 of Pattern Recognition and Machine Learning (Bishop).
Let’s approximate $P(W = w \mid D_a)$ with a Gaussian distribution centered at $w_a^*$.

From Lemma 5, we know that the gradient of $-\log P(W = w \mid D_a)$ vanishes at $w = w_a^*$, so $w_a^*$ is a mode of this distribution.

Let $A = \nabla^2 \left( -\log P(W = w \mid D_a) \right) \big|_{w = w_a^*}$. Taking a second-order Taylor expansion of $\log P(W = w \mid D_a)$ around $w_a^*$ (the first-order term vanishes by Lemma 5):

$$\log P(W = w \mid D_a) \approx \log P(W = w_a^* \mid D_a) - \frac{1}{2} (w - w_a^*)^T A \, (w - w_a^*)$$

Taking the exponential of both sides, we get:

$$P(W = w \mid D_a) \approx P(W = w_a^* \mid D_a) \, \exp\left( -\frac{1}{2} (w - w_a^*)^T A \, (w - w_a^*) \right)$$

Next, we approximate $P(W = w \mid D_a)$ as a normalized Gaussian with mean $w_a^*$ and covariance $A^{-1}$:

$$P(W = w \mid D_a) \approx \mathcal{N}\left( w \mid w_a^*, A^{-1} \right)$$

Substituting back the definition of $A$:

$$P(W = w \mid D_a) \approx \mathcal{N}\left( w \;\middle|\; w_a^*, \left[ \nabla^2 \left( -\log P(W = w_a^* \mid D_a) \right) \right]^{-1} \right)$$
From Lemma 6, we know that:

$$\nabla^2 \left( -\log P(W = w_a^* \mid D_a) \right) = \nabla^2 \left( -\log P(D_a \mid W = w_a^*) \right)$$

Thus:

$$P(W = w \mid D_a) \approx \mathcal{N}\left( w \;\middle|\; w_a^*, \left[ \nabla^2 \left( -\log P(D_a \mid W = w_a^*) \right) \right]^{-1} \right)$$

We know that the Hessian of the negative log-likelihood on $D_a$ at $w_a^*$ can be approximated by $N$ times the covariance matrix of the per-sample gradients:

$$\nabla^2 \left( -\log P(D_a \mid W = w_a^*) \right) \approx N \cdot \mathrm{Cov}_{x \sim D_a}\left[ \nabla \left( -\log P(x \mid W = w_a^*) \right) \right]$$

Assuming that the non-diagonal elements of this covariance matrix are zero, we get:

$$\nabla^2 \left( -\log P(D_a \mid W = w_a^*) \right) \approx N \cdot \text{diag}\left( \mathrm{Var}_{x \sim D_a}\left[ \nabla \left( -\log P(x \mid W = w_a^*) \right) \right] \right)$$

The variance of the per-sample gradient with respect to each individual parameter is:

$$\mathrm{Var}_{x \sim D_a}\left[ \nabla \left( -\log P(x \mid W = w_a^*) \right) \right] = \mathbb{E}_{x \sim D_a}\left[ \nabla \left( -\log P(x \mid W = w_a^*) \right)^2 \right] - \mathbb{E}_{x \sim D_a}\left[ \nabla \left( -\log P(x \mid W = w_a^*) \right) \right]^2$$

From Lemma 2, we know that:

$$\mathbb{E}_{x \sim D_a}\left[ \nabla \left( -\log P(x \mid W = w_a^*) \right) \right] = 0$$

Thus, the variance simplifies to:

$$\mathrm{Var}_{x \sim D_a}\left[ \nabla \left( -\log P(x \mid W = w_a^*) \right) \right] = \frac{1}{N} \sum_{x \in D_a} \nabla \left( -\log P(x \mid W = w_a^*) \right)^2$$

Therefore, we get:

$$\nabla^2 \left( -\log P(D_a \mid W = w_a^*) \right) \approx \text{diag}\left( \sum_{x \in D_a} \nabla \left( -\log P(x \mid W = w_a^*) \right)^2 \right)$$

Thus, the Laplace approximation for $P(W = w \mid D_a)$ is:

$$P(W = w \mid D_a) \approx \mathcal{N}\left( w \;\middle|\; w_a^*, \ \text{diag}\left( \sum_{x \in D_a} \nabla \left( -\log P(x \mid W = w_a^*) \right)^2 \right)^{-1} \right)$$
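Here is a rough PyTorch sketch of how this diagonal precision term could be computed. The names `model` and `dataset_a` are placeholders of my own; labels are treated as part of each data point, so the per-sample negative log-likelihood is the usual cross-entropy, and the one-sample-at-a-time loop is for clarity rather than speed:

```python
import torch
import torch.nn.functional as F

def diagonal_precision(model, dataset_a):
    """Sum over D_a of the squared per-sample gradients of the NLL at the current weights."""
    precision = {n: torch.zeros_like(p) for n, p in model.named_parameters()
                 if p.requires_grad}
    model.eval()
    for x, y in dataset_a:          # one (input, label) tensor pair at a time,
        model.zero_grad()           # so each squared gradient is truly per-sample
        nll = F.cross_entropy(model(x.unsqueeze(0)), y.unsqueeze(0))
        nll.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                precision[n] += p.grad.detach() ** 2
    return precision
```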
Let

$$F_a = \text{diag}\left( \sum_{x \in D_a} \nabla \left( -\log P(x \mid W = w_a^*) \right)^2 \right)$$

Let us write the approximation compactly as:

$$P(W = w \mid D_a) \approx \mathcal{N}\left( w \mid w_a^*, F_a^{-1} \right)$$

Taking the logarithm of both sides, we get:

$$\log P(W = w \mid D_a) \approx -\frac{1}{2} (w - w_a^*)^T F_a \, (w - w_a^*) + \text{const}$$

Since the normalization constant does not depend on $w$, we can ignore it when optimizing over $w$.

Returning to our expression for $\log P(W = w \mid D)$:

$$\log P(W = w \mid D) \approx \log P(D_b \mid W = w) - \frac{1}{2} (w - w_a^*)^T F_a \, (w - w_a^*) + \text{const}$$

This expression is maximized by minimizing the following loss function:

$$\mathcal{L}(w) = -\log P(D_b \mid W = w) + \frac{1}{2} (w - w_a^*)^T F_a \, (w - w_a^*)$$
The second term can be viewed as a quadratic regularization term. To control its influence on the training process, we introduce a regularization parameter $\lambda$:

$$\mathcal{L}(w) = -\log P(D_b \mid W = w) + \frac{\lambda}{2} (w - w_a^*)^T F_a \, (w - w_a^*)$$

(Note that in the original paper, $\lambda$ is described as setting how important the old task is compared to the new one.)

To add L2 weight decay, we can include the decay term:

$$\mathcal{L}(w) = -\log P(D_b \mid W = w) + \frac{\lambda}{2} (w - w_a^*)^T F_a \, (w - w_a^*) + \frac{\gamma}{2} \lVert w \rVert_2^2$$
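Putting the pieces together, a minimal PyTorch sketch of this loss might look as follows. Here `precision_a` is the diagonal term computed above, `params_a_star` is a dict of parameter copies saved right after training on $D_a$, and `lam` and `gamma` play the roles of $\lambda$ and the weight-decay coefficient; all of these names are my own, not from any reference implementation:

```python
import torch
import torch.nn.functional as F

def ewc_loss(model, inputs_b, targets_b, precision_a, params_a_star, lam, gamma=0.0):
    """Task-B NLL plus the quadratic penalty around the task-A optimum (and optional decay)."""
    nll_b = F.cross_entropy(model(inputs_b), targets_b)   # -log P(batch of D_b | w)
    quad = torch.zeros((), device=nll_b.device)            # (w - w_a*)^T F_a (w - w_a*)
    decay = torch.zeros((), device=nll_b.device)           # ||w||^2
    for n, p in model.named_parameters():
        if n in precision_a:
            quad = quad + (precision_a[n] * (p - params_a_star[n]) ** 2).sum()
        decay = decay + (p ** 2).sum()
    return nll_b + 0.5 * lam * quad + 0.5 * gamma * decay

# params_a_star would typically be captured right after training on D_a, e.g.:
# params_a_star = {n: p.detach().clone() for n, p in model.named_parameters()}
```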
Let us introduce a new task with dataset $D_c$, defined analogously to $D_a$ and $D_b$.

We have already derived an approximation for $\log P(W = w \mid D_a, D_b)$. Applying the same decomposition as before to the three datasets:

$$\log P(W = w \mid D_a, D_b, D_c) = \log P(D_c \mid W = w) + \log P(W = w \mid D_a, D_b) - \log P(D_c \mid D_a, D_b)$$

This simplifies to:

$$\log P(W = w \mid D_a, D_b, D_c) = \log P(D_c \mid W = w) + \log P(W = w \mid D_a, D_b) + \text{const}$$

After training on $D_b$ with the loss above, we arrive at parameters $w_b^*$ that (approximately) minimize $-\log P(W = w \mid D_a, D_b)$. We now take a second-degree Taylor expansion of our approximation of $-\log P(W = w \mid D_a, D_b)$ around $w_b^*$, term by term:
- 0th degree term: the value of the expression at $w = w_b^*$, which is a constant that does not depend on $w$.
- 1st degree term: This term is zero because the gradient of the expression is zero at $w = w_b^*$.
- 2nd degree term:
  - The Hessian of the constant part with zero gradient is zero.
  - The Hessian of $\frac{1}{2} (w - w_a^*)^T \text{diag} \left( \sum_{x \in D_a} \nabla (-\log P(x \mid W = w_a^*))^2 \right) (w - w_a^*)$ is $\text{diag} \left( \sum_{x \in D_a} \nabla (-\log P(x \mid W = w_a^*))^2 \right)$.
  - The Hessian of $-\log P(D_b \mid W = w)$ is the same as the Hessian of $-\log P(W = w \mid D_b)$ (Lemma 4). Thus, it can be approximated as $\text{diag} \left( \sum_{x \in D_b} \nabla (-\log P(x \mid W = w_b^*))^2 \right)$.
The complete second-degree term is:

$$\frac{1}{2} (w - w_b^*)^T \left[ \text{diag} \left( \sum_{x \in D_a} \nabla (-\log P(x \mid W = w_a^*))^2 \right) + \text{diag} \left( \sum_{x \in D_b} \nabla (-\log P(x \mid W = w_b^*))^2 \right) \right] (w - w_b^*)$$

Thus, the full second-degree Taylor approximation becomes:

$$-\log P(W = w \mid D_a, D_b) \approx \frac{1}{2} (w - w_b^*)^T \left[ \text{diag} \left( \sum_{x \in D_a} \nabla (-\log P(x \mid W = w_a^*))^2 \right) + \text{diag} \left( \sum_{x \in D_b} \nabla (-\log P(x \mid W = w_b^*))^2 \right) \right] (w - w_b^*) + \text{const}$$

Therefore:

$$\log P(W = w \mid D_a, D_b, D_c) \approx \log P(D_c \mid W = w) - \frac{1}{2} (w - w_b^*)^T \left[ \text{diag} \left( \sum_{x \in D_a} \nabla (-\log P(x \mid W = w_a^*))^2 \right) + \text{diag} \left( \sum_{x \in D_b} \nabla (-\log P(x \mid W = w_b^*))^2 \right) \right] (w - w_b^*) + \text{const}$$
Generalizing, let $D_1, D_2, \dots, D_T$ be the datasets of the tasks seen so far, let $w_t^*$ be the parameters obtained after training on task $t$, and let

$$F_t = \text{diag} \left( \sum_{x \in D_t} \nabla (-\log P(x \mid W = w_t^*))^2 \right)$$

The loss becomes:

$$\mathcal{L}(w) = -\log P(D_{T+1} \mid W = w) + \frac{1}{2} (w - w_T^*)^T \left( \sum_{t=1}^{T} F_t \right) (w - w_T^*)$$

The EWC loss becomes:

$$\mathcal{L}(w) = -\log P(D_{T+1} \mid W = w) + \frac{\lambda}{2} (w - w_T^*)^T \left( \sum_{t=1}^{T} F_t \right) (w - w_T^*)$$

Adding per-task regularization parameters $\lambda_t$:

$$\mathcal{L}(w) = -\log P(D_{T+1} \mid W = w) + \frac{1}{2} (w - w_T^*)^T \left( \sum_{t=1}^{T} \lambda_t F_t \right) (w - w_T^*)$$
As outlined by Ferenc Huszár, this is different from the loss recommended by the original paper for multiple tasks. In the original paper, a new quadratic penalty is added for each task, each anchored at that task's own optimum:

$$\mathcal{L}(w) = -\log P(D_{T+1} \mid W = w) + \sum_{t=1}^{T} \frac{\lambda_t}{2} (w - w_t^*)^T F_t \, (w - w_t^*)$$
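To make the bookkeeping difference concrete, here is a rough PyTorch-style sketch of the two penalties. All names (`model`, `precision_sum`, `w_latest`, `task_memories`) are placeholders of my own, not taken from any reference implementation:

```python
import torch

def penalty_single_anchor(model, precision_sum, w_latest):
    """The penalty derived here: summed (optionally lambda-weighted) precisions,
    one anchor at the parameters found after the most recent task."""
    total = 0.0
    for n, p in model.named_parameters():
        if n in precision_sum:
            total = total + (precision_sum[n] * (p - w_latest[n]) ** 2).sum()
    return 0.5 * total

def penalty_per_task(model, task_memories):
    """The original paper's form: one quadratic penalty per previous task,
    each with its own precision, its own lambda, and its own anchor."""
    total = 0.0
    for lam_t, precision_t, w_t in task_memories:
        for n, p in model.named_parameters():
            if n in precision_t:
                total = total + 0.5 * lam_t * (precision_t[n] * (p - w_t[n]) ** 2).sum()
    return total
```

The first variant only needs a running sum of precisions and the latest optimum; the second needs to store a precision matrix and an anchor per task.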