Outer product approximation of the Hessian matrix
The Hessian matrix is heavily studied in the optimization community, where second-order derivatives are used to optimize a function of interest (as in Newton's method). In machine learning, and especially in Bayesian inference, the Hessian matrix appears in several applications, such as Laplace's method, which approximates a distribution by a Gaussian. Although the Hessian provides additional information that improves the convergence rate in optimization or reduces a complicated distribution to a Gaussian, computing it often adds considerable computational cost. In neural networks, where the number of model parameters is very large, the Hessian matrix is often intractable given limited computation and memory.
Many efficient approximations of the Hessian matrix have been developed, either to reduce the running-time complexity or to decompose the Hessian into factors that require less memory. Hessian-free approaches, which rely on Hessian-vector products, have also attracted much research interest. This post presents an approximation of the Hessian matrix based on outer products. Note that this approximation represents the Hessian by a set of matrices whose sizes are reasonable to store in GPU memory; the trade-off is that the running time to assemble the full Hessian is still quadratic in the number of parameters. This approximation is also known as the Gauss-Newton matrix.
1 Notations
Before going into details, let’s define some notations used:
- \{x_{i}, t_{i}\}_{i = 1}^{N} is the data set, where x_{i} and t_{i} denote the input and label of the i-th data point,
- \mathbf{w} \in \mathbb{R}^{W} is the parameter vector of the model of interest, e.g. the weights of a neural network,
- \ell(.) \in \mathbb{R} is the loss function, e.g. MSE or cross-entropy,
- \mathbf{f}(x_{i}, \mathbf{w}) \in \mathbb{R}^{C} is the pre-nonlinearity output of the neural network at the final layer that has C hidden units,
- \sigma\left[ \mathbf{f}\left(x_{i}, \mathbf{w}\right) \right] \in \mathbb{R}^{C} is the activation output at the final layer. For example, in regression \sigma(z) = z is the identity function, in logistic regression \sigma(.) is the sigmoid function, and in multi-class classification \sigma(.) is the softmax function.
The loss function of interest is defined as the sum of losses over the data points: L = \sum_{i = 1}^{N} \ell\left( \sigma\left(\mathbf{f}(x_{i}, \mathbf{w})\right), t_{i}\right). Note that in the following, we will omit the label t_{i} from the loss \ell(.) to keep the notation uncluttered.
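To fix ideas, here is a minimal Python sketch of this setup. It assumes, purely for illustration, a single linear layer as the model \mathbf{f} and the cross-entropy loss; the names (X, T, W, toy functions) are hypothetical and not part of any library.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup (illustrative only): N data points, D input features,
# C output units, and a single linear layer as the model f.
N, D, C = 5, 3, 4
X = rng.normal(size=(N, D))                  # inputs x_i
T = np.eye(C)[rng.integers(0, C, size=N)]    # one-hot labels t_i
W = rng.normal(size=(C, D))                  # parameters w (here a C x D matrix)

def f(x, W):
    """Pre-nonlinearity output f(x, w) of the final layer."""
    return W @ x

def softmax(z):
    z = z - z.max()                          # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

def loss(x, t, W):
    """Per-sample loss ell(sigma(f(x, w)), t), here the cross-entropy."""
    return -np.sum(t * np.log(softmax(f(x, W))))

# Total loss L = sum_i ell(sigma(f(x_i, w)), t_i).
L = sum(loss(X[i], T[i], W) for i in range(N))
```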
2 Derivation of the approximated Hessian matrix
An element of the Hessian matrix can then be written as: \begin{aligned} \mathbf{H}_{jk} & = \frac{\partial}{\partial\mathbf{w}_{k}} \left( \frac{\partial L}{\partial \mathbf{w}_{j}} \right) = \frac{\partial}{\partial\mathbf{w}_{k}} \left( \sum_{i=1}^{N} \frac{\partial \ell \left[ \sigma \left( \mathbf{f}(x_{i}, \mathbf{w}) \right)\right]}{\partial \mathbf{w}_{j}} \right) \\ & = \frac{\partial}{\partial \mathbf{w}_{k}} \left( \sum_{i=1}^{N} \sum_{c=1}^{C} \frac{\partial\ell \left[ \sigma \left( \mathbf{f}(x_{i}, \mathbf{w}) \right)\right]}{\partial \mathbf{f}_{c} (x_{i}, \mathbf{w})} \frac{\partial \mathbf{f}_{c} (x_{i}, \mathbf{w})}{\partial \mathbf{w}_{j}} \right) \quad \text{\textcolor{ForestGreen}{(chain rule)}}\\ & = \sum_{i=1}^{N} \sum_{c=1}^{C} \frac{\partial}{\partial \mathbf{w}_{k}} \left( \frac{\partial \ell \left[ \sigma \left( \mathbf{f}(x_{i}, \mathbf{w}) \right)\right]}{\partial \mathbf{f}_{c} (x_{i}, \mathbf{w})} \frac{\partial \mathbf{f}_{c} (x_{i}, \mathbf{w})}{\partial \mathbf{w}_{j}} \right). \end{aligned}
Applying the product rule, together with the chain rule on the first factor, gives: \begin{aligned} \mathbf{H}_{jk} & = \sum_{i=1}^{N} \sum_{c=1}^{C} \left[ \sum_{l=1}^{C} \left( \frac{\partial^{2} \ell \left[ \sigma \left( \mathbf{f}(x_{i}, \mathbf{w}) \right)\right]}{\partial \mathbf{f}_{c} (x_{i}, \mathbf{w}) \, \partial \mathbf{f}_{l}(x_{i}, \mathbf{w})} \frac{\partial \mathbf{f}_{l}(x_{i}, \mathbf{w})}{\partial \mathbf{w}_{k}} \right) \frac{\partial \mathbf{f}_{c} (x_{i}, \mathbf{w})}{\partial \mathbf{w}_{j}} \right. \\ & \qquad \qquad \quad \left. + \frac{\partial \ell \left[ \sigma \left( \mathbf{f}(x_{i}, \mathbf{w}) \right)\right]}{\partial \mathbf{f}_{c} (x_{i}, \mathbf{w})} \frac{\partial^{2} \mathbf{f}_{c} (x_{i}, \mathbf{w})}{\partial \mathbf{w}_{j} \, \partial \mathbf{w}_{k}} \right]. \end{aligned}
Rearranging gives: \begin{aligned} \mathbf{H}_{jk} & = \sum_{i=1}^{N} \sum_{c=1}^{C} \frac{\partial \mathbf{f}_{c} (x_{i}, \mathbf{w})}{\partial \mathbf{w}_{j}} \sum_{l=1}^{C} \frac{\partial^{2} \ell \left[ \sigma \left( \mathbf{f}(x_{i}, \mathbf{w}) \right)\right]}{\partial \mathbf{f}_{c} (x_{i}, \mathbf{w}) \, \partial \mathbf{f}_{l}(x_{i}, \mathbf{w})} \frac{\partial \mathbf{f}_{l}(x_{i}, \mathbf{w})}{\partial \mathbf{w}_{k}} \\ & \quad + \sum_{i=1}^{N} \sum_{c=1}^{C} \underbrace{\frac{\partial \ell \left[ \sigma \left( \mathbf{f}(x_{i}, \mathbf{w}) \right)\right]}{\partial \mathbf{f}_{c} (x_{i}, \mathbf{w})}}_{\approx 0} \frac{\partial^{2} \mathbf{f}_{c} (x_{i}, \mathbf{w})}{\partial \mathbf{w}_{j} \, \partial \mathbf{w}_{k}}. \end{aligned}
Near the optimum, the model output \sigma\left( \mathbf{f}(x_{i}, \mathbf{w}) \right) is very close to its target \mathbf{t}_{i}. Hence, the derivative of the loss w.r.t. \mathbf{f}_{c} is very small, and we can approximate the Hessian as: \mathbf{H}_{jk} \approx \sum_{i=1}^{N} \sum_{c=1}^{C} \frac{\partial \mathbf{f}_{c} (x_{i}, \mathbf{w})}{\partial \mathbf{w}_{j}} \sum_{l=1}^{C} \frac{\partial^{2} \ell \left[ \sigma \left( \mathbf{f}(x_{i}, \mathbf{w}) \right)\right]}{\partial \mathbf{f}_{c} (x_{i}, \mathbf{w}) \, \partial \mathbf{f}_{l}(x_{i}, \mathbf{w})} \frac{\partial \mathbf{f}_{l}(x_{i}, \mathbf{w})}{\partial \mathbf{w}_{k}}.
Rewriting this with matrix notation yields a much simpler formulation: \boxed{ \mathbf{H} \approx \sum_{i=1}^{N} \mathbf{J}_{fi}^{\top} \mathbf{H}_{\sigma i} \mathbf{J}_{fi}, } where: \begin{aligned} \mathbf{J}_{fi} & = \nabla_{\mathbf{w}} \mathbf{f}(x_{i}, \mathbf{w}) \in \mathbb{R}^{C \times W} \quad \text{\textcolor{ForestGreen}{(Jacobian matrix of \textbf{f} w.r.t. \textbf{w})}}\\ \mathbf{H}_{\sigma i} & = \nabla_{\mathbf{f}}^{2} \ell\left[ \sigma \left( \mathbf{f}(x_{i}, \mathbf{w}) \right) \right] \in \mathbb{R}^{C \times C} \quad \text{\textcolor{ForestGreen}{(Hessian of loss w.r.t. \textbf{f})}}. \end{aligned}
Note that the Hessian matrix \mathbf{H}_{\sigma} is only of size C \times C and can be calculated analytically.
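To make the boxed formula concrete, the following sketch assembles the approximation from per-sample factors. It assumes the stacked arrays J (of shape N x C x W) and H_sigma (of shape N x C x C) have already been computed by some means; random placeholders stand in for real Jacobians here, and all names are illustrative.

```python
import numpy as np

def gauss_newton(J, H_sigma):
    """Assemble H ≈ sum_i J_i^T H_sigma_i J_i.

    J       : array of shape (N, C, W), per-sample Jacobians of f w.r.t. w.
    H_sigma : array of shape (N, C, C), per-sample Hessians of the loss w.r.t. f.
    Returns a (W, W) matrix.
    """
    # einsum contracts the class indices c, l for every sample i and sums over i.
    return np.einsum('icj,icl,ilk->jk', J, H_sigma, J)

# Tiny example with random placeholders for the factors.
N, C, W = 8, 3, 10
rng = np.random.default_rng(0)
J = rng.normal(size=(N, C, W))
H_sigma = np.stack([np.eye(C) for _ in range(N)])   # e.g. identity, as in the MSE case below
H = gauss_newton(J, H_sigma)
print(H.shape)   # (10, 10)
```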
Remark. Instead of storing the Hessian matrix \mathbf{H} of size W \times W, which requires a large amount of memory, we can store the set of smaller matrices \{\mathbf{J}_{fi}, \mathbf{H}_{\sigma i}\}_{i=1}^{N}. This reduces the memory required. Of course, the trade-off is the extra computation needed to carry out the multiplications whenever the full Hessian \mathbf{H} is assembled.
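As a sketch of how this factored storage can be exploited (an illustration of the remark, not part of the derivation above): many downstream uses only need Hessian-vector products, which can be computed from the stored factors without ever materializing the W x W matrix. The function name and array layouts are assumptions for this example.

```python
import numpy as np

def ggn_vector_product(J, H_sigma, v):
    """Compute (sum_i J_i^T H_sigma_i J_i) v without forming the W x W matrix.

    J       : (N, C, W) per-sample Jacobians of f w.r.t. w.
    H_sigma : (N, C, C) per-sample Hessians of the loss w.r.t. f.
    v       : (W,) vector.
    """
    Jv = np.einsum('icw,w->ic', J, v)            # J_i v for every sample, shape (N, C)
    HJv = np.einsum('icl,il->ic', H_sigma, Jv)   # H_sigma_i J_i v, shape (N, C)
    return np.einsum('icw,ic->w', J, HJv)        # sum_i J_i^T H_sigma_i J_i v, shape (W,)
```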
The following section will present how to calculate the matrix \mathbf{H}_{\sigma} for some commonly-used losses.
3 Derivation for \mathbf{H}_{\sigma}
3.1 Mean squared error in regression
In regression:
- C = 1
- \sigma(.) is the identity function
- \ell(f(x_{i}, \mathbf{w})) = \frac{1}{2} \left( f(x_{i}, \mathbf{w}) - t_{i} \right)^{2}.
Hence, \mathbf{H}_{\sigma} = \mathbf{I}_{1}, resulting in: \boxed{ \mathbf{H} \approx \sum_{i=1}^{N} \mathbf{J}_{fi}^{\top} \mathbf{J}_{fi}, } which agrees with the result in (Bishop and Nasrabadi 2006 - Eq. (5.84)).
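As a quick sanity check, assume a linear model f(x, \mathbf{w}) = \mathbf{w}^{\top} x, so that \mathbf{J}_{fi} = x_{i}^{\top}; the approximation then reduces to the familiar \mathbf{X}^{\top}\mathbf{X} of linear least squares. A small numerical sketch (with illustrative names) confirming this:

```python
import numpy as np

rng = np.random.default_rng(0)
N, W = 20, 4
X = rng.normal(size=(N, W))          # inputs; for f(x, w) = w^T x, J_fi = x_i^T

# Gauss-Newton matrix: sum_i J_fi^T J_fi = sum_i x_i x_i^T.
H_gn = sum(np.outer(x, x) for x in X)

# For linear least squares this coincides with the exact Hessian X^T X.
assert np.allclose(H_gn, X.T @ X)
```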
3.2 Logistic regression
In this case:
- C = 1
- \sigma(.) is the sigmoid function
- \ell(\sigma(f(x_{i}, \mathbf{w}))) = - t_{i} \ln \sigma \left( f(x_{i}, \mathbf{w}) \right) - (1 - t_{i}) \ln \left( 1 - \sigma \left( f(x_{i}, \mathbf{w}) \right) \right).
The first derivative is expressed as: \frac{\partial \ell(\sigma(f(x_{i}, \mathbf{w})))}{\partial f(x_{i}, \mathbf{w})} = - t_{i} \left( 1 - \sigma \left( f(x_{i}, \mathbf{w}) \right) \right) + (1 - t_{i}) \sigma \left( f(x_{i}, \mathbf{w}) \right) = \sigma \left( f(x_{i}, \mathbf{w}) \right) - t_{i}.
The second derivative is therefore: \frac{\partial^{2} \ell(\sigma(f(x_{i}, \mathbf{w})))}{\partial f(x_{i}, \mathbf{w})^{2}} = \sigma \left( f(x_{i}, \mathbf{w}) \right) \left[ 1 - \sigma \left( f(x_{i}, \mathbf{w}) \right) \right].
Hence: \boxed{ \mathbf{H} \approx \sum_{i=1}^{N} \sigma \left( f(x_{i}, \mathbf{w}) \right) \left[ 1 - \sigma \left( f(x_{i}, \mathbf{w}) \right) \right] \mathbf{J}_{fi}^{\top} \mathbf{J}_{fi}, } which agrees with the result derived in the literature (Bishop and Nasrabadi 2006 - Eq. (5.85)).
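Again assuming a linear model f(x, \mathbf{w}) = \mathbf{w}^{\top} x (so \mathbf{J}_{fi} = x_{i}^{\top}), the approximation reduces to the well-known \mathbf{X}^{\top} \mathbf{R} \mathbf{X} form with \mathbf{R} = \operatorname{diag}\left( \sigma_{i} (1 - \sigma_{i}) \right). A small sketch with illustrative names:

```python
import numpy as np

rng = np.random.default_rng(0)
N, W = 20, 4
X = rng.normal(size=(N, W))          # for f(x, w) = w^T x, J_fi = x_i^T
w = rng.normal(size=W)

sigma = 1.0 / (1.0 + np.exp(-X @ w))   # sigma(f(x_i, w)) for each sample
r = sigma * (1.0 - sigma)              # per-sample curvature weights

# H ≈ sum_i sigma_i (1 - sigma_i) x_i x_i^T, i.e. X^T R X with R = diag(r).
H_gn = (X * r[:, None]).T @ X
assert np.allclose(H_gn, X.T @ np.diag(r) @ X)
```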
3.3 Cross entropy loss in classification
In this case:
- \sigma(\mathbf{f}) is the softmax function,
- \ell(\sigma(\mathbf{f}(x_{i}, \mathbf{w}))) = -\sum_{c=1}^{C} \mathbf{t}_{ic} \ln \sigma_{c}(\mathbf{f}(x_{i}, \mathbf{w})).
According to the definition of the softmax function: \sigma_{c} \left( \mathbf{f} \right) = \frac{\exp(\mathbf{f}_{c})}{\sum_{k=1}^{C} \exp(\mathbf{f}_{k})}.
Hence, the derivative can be written as: \frac{\partial \sigma_{c}(\mathbf{f})}{\partial \mathbf{f}_{c}} = \frac{\exp(\mathbf{f}_{c}) \sum_{k=1}^{C} \exp(\mathbf{f}_{k}) - \exp(2 \mathbf{f}_{c})}{\left[ \sum_{k=1}^{C} \exp(\mathbf{f}_{k}) \right]^{2}} = \sigma_{c}(\mathbf{f}) \left[ 1 - \sigma_{c}(\mathbf{f}) \right], and \frac{\partial \sigma_{c}(\mathbf{f})}{\partial \mathbf{f}_{k}} = - \sigma_{c}(\mathbf{f}) \sigma_{k}(\mathbf{f}), \forall k \neq c.
An element of the Jacobian vector of the loss w.r.t. \mathbf{f} can be written as: \begin{aligned} \frac{\partial \ell(\sigma(\mathbf{f}(x_{i}, \mathbf{w})))}{\partial \mathbf{f}_{c}(x_{i}, \mathbf{w})} & = - \sum_{k=1}^{C} \frac{\mathbf{t}_{ik}}{\sigma_{k}(\mathbf{f})} \frac{\partial \sigma_{k}(\mathbf{f})}{\partial \mathbf{f}_{c}} \\ & = - \mathbf{t}_{ic} \left[ 1 - \sigma_{c}(\mathbf{f}) \right] + \sum_{\substack{k=1\\k \neq c}}^{C} \mathbf{t}_{ik} \sigma_{c}(\mathbf{f}) \\ & = - \mathbf{t}_{ic} + \sigma_{c}(\mathbf{f}) \underbrace{\sum_{k=1}^{C} \mathbf{t}_{ik}}_{1}\\ & = \sigma_{c}(\mathbf{f}) - \mathbf{t}_{ic}. \end{aligned}
Hence, the Jacobian vector can be expressed as: \nabla_{\mathbf{f}} \ell(\sigma(\mathbf{f}(x_{i}, \mathbf{w}))) = \sigma(\mathbf{f}(x_{i}, \mathbf{w})) - \mathbf{t}_{i}.
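A quick finite-difference check of this gradient (a standalone sketch; the helper names are illustrative):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def cross_entropy(f, t):
    return -np.sum(t * np.log(softmax(f)))

rng = np.random.default_rng(0)
C = 4
f = rng.normal(size=C)
t = np.eye(C)[1]                          # a one-hot target

analytic = softmax(f) - t                 # the gradient derived above
eps = 1e-6
numeric = np.array([
    (cross_entropy(f + eps * e, t) - cross_entropy(f - eps * e, t)) / (2 * eps)
    for e in np.eye(C)
])
assert np.allclose(analytic, numeric, atol=1e-6)
```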
The Hessian matrix follows by differentiating the gradient once more: \nabla_{\mathbf{f}}^{2} \ell(\sigma(\mathbf{f}(x_{i}, \mathbf{w}))) = \nabla_{\mathbf{f}} \left[ \sigma(\mathbf{f}(x_{i}, \mathbf{w})) - \mathbf{t}_{i} \right] = \nabla_{\mathbf{f}} \sigma(\mathbf{f}(x_{i}, \mathbf{w})).
Or, in the explicit matrix form: \mathbf{H}_{\sigma} = \begin{bmatrix} \sigma_{1}(\mathbf{f}) \left[ 1 - \sigma_{1}(\mathbf{f}) \right] & - \sigma_{1}(\mathbf{f}) \sigma_{2}(\mathbf{f}) & - \sigma_{1}(\mathbf{f}) \sigma_{3}(\mathbf{f}) & \ldots & - \sigma_{1}(\mathbf{f}) \sigma_{C}(\mathbf{f})\\ - \sigma_{2}(\mathbf{f}) \sigma_{1}(\mathbf{f}) & \sigma_{2}(\mathbf{f}) \left[ 1 - \sigma_{2}(\mathbf{f}) \right] & - \sigma_{2}(\mathbf{f}) \sigma_{3}(\mathbf{f}) & \ldots & - \sigma_{2}(\mathbf{f}) \sigma_{C}(\mathbf{f})\\ \vdots & \vdots & \ddots & \vdots & \vdots\\ - \sigma_{C}(\mathbf{f}) \sigma_{1}(\mathbf{f}) & - \sigma_{C}(\mathbf{f}) \sigma_{2}(\mathbf{f}) & - \sigma_{C}(\mathbf{f}) \sigma_{3}(\mathbf{f}) & \ldots & \sigma_{C}(\mathbf{f}) \left[ 1 - \sigma_{C}(\mathbf{f}) \right] \end{bmatrix}.
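In compact form this is \mathbf{H}_{\sigma} = \operatorname{diag}\left( \sigma(\mathbf{f}) \right) - \sigma(\mathbf{f}) \sigma(\mathbf{f})^{\top}, which the following sketch (illustrative names only) verifies against the entry-by-entry matrix above:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
C = 4
s = softmax(rng.normal(size=C))

# Compact form: H_sigma = diag(sigma(f)) - sigma(f) sigma(f)^T.
H_sigma = np.diag(s) - np.outer(s, s)

# The explicit matrix written out above, built entry by entry.
H_explicit = np.array([[s[c] * (1.0 - s[c]) if c == l else -s[c] * s[l]
                        for l in range(C)] for c in range(C)])
assert np.allclose(H_sigma, H_explicit)
```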
4 Conclusion
In this post, we derived an approximation of the Hessian matrix. The Gauss-Newton matrix is a good approximation: it is positive semi-definite (whenever each \mathbf{H}_{\sigma i} is), and it can be stored efficiently as a set of smaller matrices. Of course, we have not escaped the curse of dimensionality, since the running time to assemble the full Hessian is still quadratic in the number of model parameters. One final note is that the approximated Hessian should be used with care, since the derivation assumes the parameters are near a minimum of the considered loss function.
5 References
Bishop, Christopher M., and Nasser M. Nasrabadi. 2006. Pattern Recognition and Machine Learning. New York: Springer.