<?xml version="1.0" encoding="UTF-8"?>
<rss  xmlns:atom="http://www.w3.org/2005/Atom" 
      xmlns:media="http://search.yahoo.com/mrss/" 
      xmlns:content="http://purl.org/rss/1.0/modules/content/" 
      xmlns:dc="http://purl.org/dc/elements/1.1/" 
      version="2.0">
<channel>
<title>Cuong Nguyen</title>
<link>https://cnguyen10.github.io/blog.html</link>
<atom:link href="https://cnguyen10.github.io/blog.xml" rel="self" type="application/rss+xml"/>
<description>Probabilistic machine learning</description>
<generator>quarto-1.9.38</generator>
<lastBuildDate>Sun, 19 Nov 2023 00:00:00 GMT</lastBuildDate>
<item>
  <title>Stochastic gradient and Hamiltonian Monte Carlo</title>
  <dc:creator>Cuong Nguyen</dc:creator>
  <link>https://cnguyen10.github.io/posts/stochastic_grad_hamiltonian_monte_carlo/</link>
  <description><![CDATA[ 




<p>This post is to introduce the formulation of stochastic gradient descent as a Monte Carlo sampling to approximate the posterior of the variables of interest.</p>
<section id="motivation-of-monte-carlo-sampling" class="level2" data-number="1">
<h2 data-number="1" class="anchored" data-anchor-id="motivation-of-monte-carlo-sampling"><span class="header-section-number">1</span> Motivation of Monte Carlo sampling</h2>
<p>According to <span class="citation" data-cites="mackay2003information">(MacKay 2003, chap. 29)</span>, Monte Carlo based methods make use of random numbers (or in particular, random variables) to solve one or both of the following problems.</p>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Problem 1 - generate samples
</div>
</div>
<div class="callout-body-container callout-body">
<p>Generate samples <img src="https://latex.codecogs.com/png.latex?%5C%7B%5Ctheta%5E%7B(r)%7D%5C%7D_%7Br%20=%201%7D%5E%7BR%7D"> from a given probability distribution <img src="https://latex.codecogs.com/png.latex?P(%5Ctheta)">.</p>
</div>
</div>
<div class="callout callout-style-default callout-important callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Important</span>Problem 2 - estimate an expected value
</div>
</div>
<div class="callout-body-container callout-body">
<p>Estimate the expectation of a given function <img src="https://latex.codecogs.com/png.latex?%5Cell(%5Ctheta)"> under a given distribution <img src="https://latex.codecogs.com/png.latex?P(%5Ctheta)">: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Coverline%7B%5Cell%7D%20=%20%5Cint%20%5Cell(%5Ctheta)%20%5C,%20P(%5Ctheta)%20%5C,%20%5Coperatorname%7Bd%7D%5E%7BN%7D%20%5Ctheta,%0A"> where <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> is assumed to be an <img src="https://latex.codecogs.com/png.latex?N">-dimensional vector with real components <img src="https://latex.codecogs.com/png.latex?%5Ctheta_%7Bn%7D">.</p>
</div>
</div>
<p>It is assumed that <img src="https://latex.codecogs.com/png.latex?P(%5Ctheta)"> is sufficiently complex that we can neither <em>(i)</em> sample from it by conventional techniques nor <em>(ii)</em> evaluate those expectations by exact methods. This motivates the study of Monte Carlo approximation methods.</p>
<p>The majority of studies on Monte Carlo methods focus on the first problem (sampling): once the first problem is solved, the second can be solved by using the samples to form a Monte Carlo estimate of the expectation: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Chat%7B%5Cell%7D%20=%20%5Cfrac%7B1%7D%7BR%7D%20%5Csum_%7Br%20=%201%7D%5E%7BR%7D%20%5Cell(%5Ctheta%5E%7B(r)%7D),%0A"> where <img src="https://latex.codecogs.com/png.latex?%5C%7B%5Ctheta%5E%7B(r)%7D%5C%7D_%7Br%20=%201%7D%5E%7BR%7D"> are generated from <img src="https://latex.codecogs.com/png.latex?P(%5Ctheta)">.</p>
<p>Under this approximation, <img src="https://latex.codecogs.com/png.latex?%5Chat%7B%5Cell%7D"> is an unbiased estimator of the exact expectation <img src="https://latex.codecogs.com/png.latex?%5Coverline%7B%5Cell%7D">.</p>
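<p>As a rough illustration of the estimator above, the following minimal sketch averages a function over samples drawn from a distribution that is easy to sample. The toy target (a standard normal), the function (the square of the sample), and the helper name <code>monte_carlo_expectation</code> are assumptions made for this example only; they do not come from the referenced material.</p>

```python
import random


def monte_carlo_expectation(ell, sampler, num_samples, seed=0):
    """Average ell(theta) over num_samples draws produced by sampler(rng)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        total += ell(sampler(rng))
    return total / num_samples


# Assumed toy setting: P(theta) = N(0, 1) and ell(theta) = theta^2,
# whose exact expectation is the variance of P, namely 1.
estimate = monte_carlo_expectation(
    ell=lambda theta: theta ** 2,
    sampler=lambda rng: rng.gauss(0.0, 1.0),
    num_samples=20000,
)
print(estimate)
```

<p>The estimate fluctuates around the exact value, with a spread that shrinks as the number of samples grows. The hard part in practice is the sampler itself, which is what the remainder of the post addresses.</p>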
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Why is sampling from <img src="https://latex.codecogs.com/png.latex?P(%5Ctheta)"> hard?
</div>
</div>
<div class="callout-body-container callout-body">
<p>We will assume that the density from which we wish to draw samples, <img src="https://latex.codecogs.com/png.latex?P(%5Ctheta)">, can be evaluated, at least to within a multiplicative constant. In other words, we can evaluate a function <img src="https://latex.codecogs.com/png.latex?P%5E%7B*%7D(%5Ctheta)"> such that: <span id="eq-exact-distribution"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20P(%5Ctheta)%20=%20%5Cfrac%7BP%5E%7B*%7D(%5Ctheta)%7D%7BZ%7D,%0A%5Ctag%7B1%7D"></span> where <img src="https://latex.codecogs.com/png.latex?Z"> is the normalising constant (that we do not know): <span id="eq-normalising-constant"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20Z%20=%20%5Cint%20P%5E%7B*%7D(%5Ctheta)%20%5C,%20%5Coperatorname%7Bd%7D%5E%7BN%7D%5Ctheta.%0A%5Ctag%7B2%7D"></span> Thus, it is hard to draw samples from <img src="https://latex.codecogs.com/png.latex?P(%5Ctheta)"> since <img src="https://latex.codecogs.com/png.latex?Z"> is typically unknown. Even if we know <img src="https://latex.codecogs.com/png.latex?Z">, drawing samples from <img src="https://latex.codecogs.com/png.latex?P(%5Ctheta)"> is still a challenging problem, especially in high-dimensional spaces, because there is no obvious way to sample from <img src="https://latex.codecogs.com/png.latex?P(%5Ctheta)"> without enumerating all of the possible states.</p>
</div>
</div>
<p>There are various sampling techniques to generate samples from a given distribution, such as <em>importance sampling</em>, <em>rejection sampling</em> and the <em>Metropolis - Hastings</em> method. Here, we focus on a specific method, known as <em>Hamiltonian Monte Carlo</em>, which belongs to the family of <em>Metropolis - Hastings</em> methods.</p>
</section>
<section id="the-metropolis---hastings-method" class="level2" data-number="2">
<h2 data-number="2" class="anchored" data-anchor-id="the-metropolis---hastings-method"><span class="header-section-number">2</span> The Metropolis - Hastings method</h2>
<p>The Metropolis - Hastings algorithm uses a proposal density <img src="https://latex.codecogs.com/png.latex?Q(%5Ctheta;%20%5Ctheta%5E%7B(t)%7D)"> which depends on the current state <img src="https://latex.codecogs.com/png.latex?%5Ctheta%5E%7B(t)%7D">. For example, <img src="https://latex.codecogs.com/png.latex?Q(%5Ctheta;%20%5Ctheta%5E%7B(t)%7D)"> might be a simple Gaussian distribution centred on the current <img src="https://latex.codecogs.com/png.latex?%5Ctheta%5E%7B(t)%7D">. The proposal density <img src="https://latex.codecogs.com/png.latex?Q(%5Ctheta;%20%5Ctheta%5E%7B(t)%7D)"> can be any fixed probability distribution from which we can easily sample.</p>
<p>As before, it is assumed that the un-normalised probability <img src="https://latex.codecogs.com/png.latex?P%5E%7B*%7D(%5Ctheta)"> can be evaluated for any <img src="https://latex.codecogs.com/png.latex?%5Ctheta">. One can generate a candidate for the next state, <img src="https://latex.codecogs.com/png.latex?%5Ctheta%5E%7B%5Cprime%7D">, from the proposal distribution <img src="https://latex.codecogs.com/png.latex?Q(%5Ctheta;%20%5Ctheta%5E%7B(t)%7D)">. To decide whether to accept the candidate, a quantity known as the Metropolis - Hastings score is calculated. Depending on the value of the score, the candidate is either <em>(i)</em> accepted outright (when the score is at least 1), or <em>(ii)</em> accepted with a probability equal to the score.</p>
<ul>
<li>If the step is accepted, then <img src="https://latex.codecogs.com/png.latex?%5Ctheta%5E%7B(t%20+%201)%7D%20=%20%5Ctheta%5E%7B%5Cprime%7D">.</li>
<li>Otherwise, the previous state is kept: <img src="https://latex.codecogs.com/png.latex?%5Ctheta%5E%7B(t%20+%201)%7D%20=%20%5Ctheta%5E%7B(t)%7D">.</li>
</ul>
<p>The details of the Metropolis - Hastings algorithm can be seen in Algorithm 1.</p>
<div id="algo-metropolis-hastings" class="pseudocode-container quarto-float" data-indent-size="1.2em" data-line-number="true" data-caption-prefix="Algorithm" data-comment-delimiter="//" data-pseudocode-number="1" data-no-end="false" data-line-number-punc=":">
<div class="pseudocode">
\begin{algorithm} \caption{The Metropolis - Hastings sampling method} \begin{algorithmic} \Procedure{Metropolis-Hastings}{$P^{*}(\theta), Q(\theta; \theta^{(t)})$} \State initialise $\theta^{(0)}$ \While{$t = 0, 1, \dots, T, \dots, T_{\mathrm{end}}$} \State $\theta^{\prime} \gets$ \Call{sample-from-proposal-distribution}{$Q(\theta; \theta^{(t)})$} \Comment{generate a new state} \State $a \gets \displaystyle \frac{P^{*}(\theta^{\prime})}{P^{*}(\theta^{(t)})} \frac{Q(\theta^{(t)}; \theta^{\prime})}{Q(\theta^{\prime}; \theta^{(t)})}$ \Comment{calculate Metropolis - Hastings score} \State sample: $u \sim \mathrm{uniform}(0, 1)$ \If{$u &lt; a$} \Comment{always true when $a \ge 1$} \State $\theta^{(t + 1)} \gets \theta^{\prime}$ \Comment{accept the new state} \Else \State $\theta^{(t + 1)} \gets \theta^{(t)}$ \Comment{reject the new state} \EndIf \EndWhile \State return $\{\theta^{(t)}\}_{t = T}^{T_{\mathrm{end}}}$ \EndProcedure \end{algorithmic} \end{algorithm}
</div>
</div>
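<p>The Metropolis - Hastings procedure can be sketched in a few lines of Python. This is only a minimal one-dimensional illustration under stated assumptions: the proposal is a symmetric Gaussian random walk (so the ratio of proposal densities in the score equals 1), the score is handled in log space for numerical stability, and the target <code>log_p_star</code>, the step size and the burn-in length are hypothetical choices for the example.</p>

```python
import math
import random


def metropolis_hastings(log_p_star, theta0, num_steps, burn_in, step_size=1.0, seed=0):
    """Random-walk Metropolis - Hastings for a 1-D unnormalised log-density."""
    rng = random.Random(seed)
    theta = theta0
    samples = []
    for t in range(num_steps):
        # Symmetric Gaussian proposal centred on the current state, so the
        # ratio of proposal densities in the score cancels out.
        proposal = rng.gauss(theta, step_size)
        log_a = log_p_star(proposal) - log_p_star(theta)
        # Accept with probability min(1, a); 1 - random() lies in (0, 1].
        if log_a > math.log(1.0 - rng.random()):
            theta = proposal
        # On rejection the current state is written onto the list again.
        if t >= burn_in:
            samples.append(theta)
    return samples


# Assumed toy target: P*(theta) = exp(-(theta - 2)^2 / 2), i.e. N(2, 1) up to Z.
samples = metropolis_hastings(lambda th: -0.5 * (th - 2.0) ** 2, 0.0, 20000, 2000)
sample_mean = sum(samples) / len(samples)
sample_var = sum((s - sample_mean) ** 2 for s in samples) / len(samples)
print(sample_mean, sample_var)
```

<p>Only the un-normalised density is evaluated, so the normalising constant is never needed. The sample mean and variance should land close to those of the toy target, although successive samples remain correlated.</p>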
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Different from rejection sampling
</div>
</div>
<div class="callout-body-container callout-body">
<p>In rejection sampling, rejected points are discarded and have no influence on the list of samples <img src="https://latex.codecogs.com/png.latex?%5C%7B%5Ctheta%5E%7B(r)%7D%5C%7D"> that are collected to represent the distribution <img src="https://latex.codecogs.com/png.latex?P(%5Ctheta)">. In the Metropolis - Hastings method, rejected points are also discarded, but a rejection causes the current state <img src="https://latex.codecogs.com/png.latex?%5Ctheta%5E%7B(t)%7D"> to be written onto the list again.</p>
</div>
</div>
<p><strong>Convergence of the Metropolis - Hastings method</strong> &nbsp; It has been shown that for any positive proposal distribution, i.e., <img src="https://latex.codecogs.com/png.latex?Q(%5Ctheta;%20%5Ctheta%5E%7B(t)%7D)%20%3E%200,%20%5Cforall%20%5Ctheta,%20%5Ctheta%5E%7B(t)%7D">, as <img src="https://latex.codecogs.com/png.latex?t%5Cto+%5Cinfty">, the probability distribution of <img src="https://latex.codecogs.com/png.latex?%5Ctheta%5E%7B(t)%7D"> converges to the target distribution <img src="https://latex.codecogs.com/png.latex?P(%5Ctheta)"> defined in Equation&nbsp;1.</p>
<div class="callout callout-style-default callout-warning callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Warning</span>Dependency of samples generated from the Metropolis - Hastings method
</div>
</div>
<div class="callout-body-container callout-body">
<p>The Metropolis - Hastings method is an example of a <em>Markov chain Monte Carlo</em> method (abbreviated MCMC). In MCMC methods, a Markov process is employed to generate a sequence of states <img src="https://latex.codecogs.com/png.latex?%5C%7B%5Ctheta%5C%7D">, where each sample <img src="https://latex.codecogs.com/png.latex?%5Ctheta%5E%7B(t)%7D"> has a probability distribution that depends on the previous state, <img src="https://latex.codecogs.com/png.latex?%5Ctheta%5E%7B(t%20-%201)%7D">. Because successive samples are dependent, the Markov chain may need to be run for a considerable amount of time to generate samples that are effectively independent draws from the target distribution <img src="https://latex.codecogs.com/png.latex?P(%5Ctheta)">.</p>
</div>
</div>
</section>
<section id="the-hamiltonian-monte-carlo-method" class="level2" data-number="3">
<h2 data-number="3" class="anchored" data-anchor-id="the-hamiltonian-monte-carlo-method"><span class="header-section-number">3</span> The Hamiltonian Monte Carlo method</h2>
<p>The Hamiltonian Monte Carlo method is an instance of the Metropolis - Hastings method that is applicable to continuous domains. It makes use of gradient information to reduce random walk behaviour, potentially resulting in a more efficient MCMC method. In particular, it replaces the proposal distribution <img src="https://latex.codecogs.com/png.latex?Q(%5Ctheta;%20%5Ctheta%5E%7B(t)%7D)"> by an implicit distribution in the form of a differential equation.</p>
<p>Similar to the Metropolis - Hastings method, we assume that the density <img src="https://latex.codecogs.com/png.latex?P(%5Ctheta)"> is known up to a normalising constant and written in terms of the <em>potential energy</em> <img src="https://latex.codecogs.com/png.latex?U(%5Ctheta)"> as follows: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20P(%5Ctheta)%20=%20%5Cfrac%7B%5Cexp(-U(%5Ctheta))%7D%7BZ%7D.%0A"></p>
<p>The <em>potential energy</em>, <img src="https://latex.codecogs.com/png.latex?U(%5Ctheta)">, is defined as: <span id="eq-potential-energy"><img src="https://latex.codecogs.com/png.latex?%0A%5Cboxed%7B%0A%20%20%20%20U(%5Ctheta)%20=%20-%20%5Csum_%7Bx%20%5Cin%20%5Cmathcal%7BD%7D%7D%20%5Cln%20p(x%20%7C%20%5Ctheta)%20-%20%5Cln%20p(%5Ctheta),%0A%7D%0A%5Ctag%7B3%7D"></span> where <img src="https://latex.codecogs.com/png.latex?p(x%20%7C%20%5Ctheta)"> is a likelihood function, and <img src="https://latex.codecogs.com/png.latex?p(%5Ctheta)"> is the prior distribution of <img src="https://latex.codecogs.com/png.latex?%5Ctheta">.</p>
<p>The Hamiltonian Monte Carlo method augments the variable of interest, <img src="https://latex.codecogs.com/png.latex?%5Ctheta">, by an <img src="https://latex.codecogs.com/png.latex?N_%7B%5Crho%7D">-dimensional vector of <em>momentum variables</em>, <img src="https://latex.codecogs.com/png.latex?%5Crho">. A common analogy is that <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> is the position, while <img src="https://latex.codecogs.com/png.latex?%5Crho"> is the momentum (mass times velocity) of an object of interest. In that case, the <em>kinetic energy</em> <img src="https://latex.codecogs.com/png.latex?K(%5Crho)"> is defined as follows: <span id="eq-kinetic-energy"><img src="https://latex.codecogs.com/png.latex?%0A%5Cboxed%7B%0A%20%20%20%20K(%5Crho)%20=%20%5Cfrac%7B1%7D%7B2%7D%20%5Crho%5E%7B%5Ctop%7D%20M%5E%7B-1%7D%20%5Crho,%0A%7D%0A%5Ctag%7B4%7D"></span> where <img src="https://latex.codecogs.com/png.latex?M%20%5Cin%20%5Cmathbb%7BR%7D%5E%7BN_%7B%5Crho%7D%20%5Ctimes%20N_%7B%5Crho%7D%7D"> is a symmetric positive definite matrix known as the <em>mass matrix</em>.</p>
<p>The Hamiltonian (total energy) of the whole system can then be defined as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20H(%5Ctheta,%20%5Crho)%20=%20U(%5Ctheta)%20+%20K(%5Crho).%0A"></p>
<p>One can then define the joint probability density as: <span id="eq-joint-distribution"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20p_%7BH%7D(%5Ctheta,%20%5Crho)%20=%20%5Cfrac%7B%5Cexp(-H(%5Ctheta,%20%5Crho))%7D%7BZ_%7BH%7D%7D%20=%20%5Cfrac%7B1%7D%7BZ_%7BH%7D%7D%20%5Cexp(-U(%5Ctheta))%20%5C,%20%5Cexp(-K(%5Crho)).%0A%5Ctag%7B5%7D"></span></p>
<p>Since the probability distribution <img src="https://latex.codecogs.com/png.latex?p_%7BH%7D"> factorises into a term in <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> and a term in <img src="https://latex.codecogs.com/png.latex?%5Crho">, the marginal distribution of <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> is the desired distribution <img src="https://latex.codecogs.com/png.latex?p(%5Ctheta)%20=%20%5Cfrac%7B%5Cexp(-U(%5Ctheta))%7D%7BZ%7D">. Thus, simply discarding the momentum variables <img src="https://latex.codecogs.com/png.latex?%5Crho"> yields a sequence of samples <img src="https://latex.codecogs.com/png.latex?%5C%7B%5Ctheta%5E%7B(t)%7D%5C%7D"> that asymptotically come from <img src="https://latex.codecogs.com/png.latex?P(%5Ctheta)">.</p>
<p>The Hamiltonian dynamics are governed by Hamilton's equations: <span id="eq-hamiltonian-dynamics"><img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Bdcases%7D%0A%20%20%20%20%5Cfrac%7B%5Coperatorname%7Bd%7D%5Ctheta%7D%7B%5Coperatorname%7Bd%7Dt%7D%20&amp;%20=%20%5Cfrac%7B%5Cpartial%20H(%5Ctheta,%20%5Crho)%7D%7B%5Cpartial%20%5Crho%7D%20=%20M%5E%7B-1%7D%20%5Crho%20%5C%5C%0A%20%20%20%20&amp;%20%5C%5C%0A%20%20%20%20%5Cfrac%7B%5Coperatorname%7Bd%7D%5Crho%7D%7B%5Coperatorname%7Bd%7Dt%7D%20&amp;%20=%20-%20%5Cfrac%7B%5Cpartial%20H(%5Ctheta,%20%5Crho)%7D%7B%5Cpartial%20%5Ctheta%7D%20=%20-%5Cnabla_%7B%5Ctheta%7D%20U(%5Ctheta).%0A%5Cend%7Bdcases%7D%0A%5Ctag%7B6%7D"></span></p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>2D analogy of the Hamiltonian dynamics <span class="citation" data-cites="chen2014stochastic">(Chen et al. 2014)</span>
</div>
</div>
<div class="callout-body-container callout-body">
<p>As an analogy for the Hamiltonian dynamics, one can imagine a hockey puck sliding over a frictionless ice surface of varying height. The potential energy is proportional to the height of the surface at the current position, <img src="https://latex.codecogs.com/png.latex?%5Ctheta">, of the puck, while the kinetic energy is determined by the momentum, <img src="https://latex.codecogs.com/png.latex?%5Crho">, and the mass, <img src="https://latex.codecogs.com/png.latex?M">, of the hockey puck.</p>
<p>If the surface is flat: <img src="https://latex.codecogs.com/png.latex?%5Cnabla_%7B%5Ctheta%7D%20U(%5Ctheta)%20=%200,"> then the hockey puck will move at a constant speed.</p>
<p>If it is going uphill (positive slope: <img src="https://latex.codecogs.com/png.latex?%5Cnabla_%7B%5Ctheta%7D%20U(%5Ctheta)%20%3E%200">), the kinetic energy decreases as the potential energy increases until the kinetic energy reaches 0 (equivalently, <img src="https://latex.codecogs.com/png.latex?%5Crho%20=%200">). The hockey puck stops for an instant and then begins to slide back down the hill, which increases the kinetic energy and decreases the potential energy.</p>
</div>
</div>
<p>Equation&nbsp;6 defines the transformation of the two variables <img src="https://latex.codecogs.com/png.latex?(%5Ctheta,%20%5Crho)"> from time <img src="https://latex.codecogs.com/png.latex?t"> to time <img src="https://latex.codecogs.com/png.latex?t%20+%20%5CDelta%20t."> This transformation is <em>reversible</em>. Moreover, the Hamiltonian <img src="https://latex.codecogs.com/png.latex?H(%5Ctheta,%20%5Crho)"> is conserved along the trajectory: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cfrac%7B%5Coperatorname%7Bd%7D%20H%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20=%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Cfrac%7B%5Coperatorname%7Bd%7D%20%5Ctheta_%7Bi%7D%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20%5Cfrac%7B%5Cpartial%20H%7D%7B%5Cpartial%20%5Ctheta_%7Bi%7D%7D%20+%20%5Cfrac%7B%5Coperatorname%7Bd%7D%20%5Crho_%7Bi%7D%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20%5Cfrac%7B%5Cpartial%20H%7D%7B%5Cpartial%20%5Crho_%7Bi%7D%7D%20=%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Cfrac%7B%5Cpartial%20H%7D%7B%5Cpartial%20%5Crho_%7Bi%7D%7D%20%5Cfrac%7B%5Cpartial%20H%7D%7B%5Cpartial%20%5Ctheta_%7Bi%7D%7D%20-%5Cfrac%7B%5Cpartial%20H%7D%7B%5Cpartial%20%5Ctheta_%7Bi%7D%7D%20%5Cfrac%7B%5Cpartial%20H%7D%7B%5Cpartial%20%5Crho_%7Bi%7D%7D%20=%200.%0A"></p>
<p>This makes any proposal <img src="https://latex.codecogs.com/png.latex?(%5Ctheta,%20%5Crho)"> obtained from such a perfect simulation always acceptable. If the simulation is imperfect, due to the finite step size when performing the integration for example, then some of the dynamical proposals will be rejected. The rejection rule makes use of the change in <img src="https://latex.codecogs.com/png.latex?H(%5Ctheta,%20%5Crho)">, which is zero if the simulation is perfect. Please refer to Algorithm 2 for further details of the Hamiltonian Monte Carlo method.</p>
<div id="algo-hamiltonian-mc" class="pseudocode-container quarto-float" data-indent-size="1.2em" data-line-number="true" data-caption-prefix="Algorithm" data-comment-delimiter="//" data-pseudocode-number="2" data-no-end="false" data-line-number-punc=":">
<div class="pseudocode">
\begin{algorithm} \caption{Hamiltonian Monte Carlo method} \begin{algorithmic} \Procedure{Hamiltonian-MC}{$U(.), M, \varepsilon, \tau$} \State initialise $\theta^{(1)}$ \While{$t = 1, 2, \dots, T, \dots, T_{\mathrm{end}}$} \State sample momentum: $\rho^{(t)} \sim \mathcal{N}(0, M)$ \State evaluate total energy: $H \gets U(\theta^{(t)}) + K(\rho^{(t)})$ \State $\theta^{(t, 1)} \gets \theta^{(t)}$ \State $\rho^{(t, 1)} \gets \rho^{(t)}$ \For{$i = 1, 2, \dots, \tau$} \Comment{Simulate for next state} \State $\rho^{(t, i + \frac{1}{2})} \gets \rho^{(t, i)} - \frac{1}{2} \varepsilon \nabla_{\theta} U(\theta^{(t, i)})$ \Comment{make a half-step in $\rho$} \State $\theta^{(t, i + 1)} \gets \theta^{(t, i)} + \varepsilon M^{-1} \rho^{(t, i + \frac{1}{2})}$ \Comment{make a step in $\theta$} \State $\rho^{(t, i + 1)} \gets \rho^{(t, i + \frac{1}{2})} - \frac{1}{2} \varepsilon \nabla_{\theta} U(\theta^{(t, i + 1)})$ \Comment{make another half-step in $\rho$} \EndFor \State $\theta^{\prime} \gets \theta^{(t, \tau + 1)}$ \Comment{new state of $\theta$} \State $\rho^{\prime} \gets \rho^{(t, \tau + 1)}$ \Comment{new state of momentum} \State evaluate total energy with the new state: $H_{\mathrm{new}} \gets U(\theta^{\prime}) + K(\rho^{\prime})$ \State calculate: $\operatorname{d}H \gets H_{\mathrm{new}} - H$ \State sample: $u \sim \mathrm{uniform}(0, 1)$ \If{$u &lt; \exp(-\operatorname{d}H)$} \Comment{Metropolis - Hastings step} \State $\theta^{(t + 1)} \gets \theta^{\prime}$ \Comment{accept the new state} \Else \State $\theta^{(t + 1)} \gets \theta^{(t)}$ \Comment{reject the new state} \EndIf \EndWhile \State return $\{\theta^{(t)}\}_{t = T}^{T_{\mathrm{end}}}$ \EndProcedure \end{algorithmic} \end{algorithm}
</div>
</div>
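<p>For concreteness, here is a minimal one-dimensional sketch of Hamiltonian Monte Carlo with a leapfrog integrator. It is an illustration under stated assumptions rather than a reference implementation: the mass is a positive scalar, the momentum is drawn from a zero-mean Gaussian whose variance equals the mass (the marginal implied by the kinetic energy above), the second half-step uses the gradient at the updated position, and the toy potential, step size and trajectory length are hypothetical choices.</p>

```python
import math
import random


def leapfrog(theta, rho, grad_u, step_size, num_steps, mass):
    """Integrate Hamilton's equations with the leapfrog scheme."""
    for _ in range(num_steps):
        rho = rho - 0.5 * step_size * grad_u(theta)  # half-step in momentum
        theta = theta + step_size * rho / mass       # full step in position
        rho = rho - 0.5 * step_size * grad_u(theta)  # half-step at the NEW position
    return theta, rho


def hmc(u, grad_u, theta0, num_samples, step_size=0.2, num_steps=20, mass=1.0, seed=0):
    rng = random.Random(seed)
    theta = theta0
    samples = []
    for _ in range(num_samples):
        rho = rng.gauss(0.0, math.sqrt(mass))  # momentum with variance equal to the mass
        h_old = u(theta) + 0.5 * rho * rho / mass
        theta_new, rho_new = leapfrog(theta, rho, grad_u, step_size, num_steps, mass)
        h_new = u(theta_new) + 0.5 * rho_new * rho_new / mass
        # Metropolis - Hastings step: accept with probability min(1, exp(-dH)).
        if -(h_new - h_old) > math.log(1.0 - rng.random()):
            theta = theta_new
        samples.append(theta)
    return samples


# Assumed toy potential: U(theta) = (theta - 1)^2 / 2, i.e. a N(1, 1) target.
samples = hmc(lambda th: 0.5 * (th - 1.0) ** 2, lambda th: th - 1.0, 0.0, 5000)
hmc_mean = sum(samples) / len(samples)
hmc_var = sum((s - hmc_mean) ** 2 for s in samples) / len(samples)
print(hmc_mean, hmc_var)
```

<p>Because the leapfrog scheme nearly conserves the total energy for small step sizes, most proposals are accepted even though each one moves far from the current state.</p>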
<p>Despite its efficiency, the Hamiltonian Monte Carlo method still requires a pass over the <em>entire</em> dataset, both to evaluate the gradient used in the integration for <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> and to perform the Metropolis - Hastings step that decides whether to accept or reject the new state generated from the Hamiltonian dynamics. Hence, through the lens of machine learning, it is impractical, especially for large-scale datasets. This motivates further studies and development to make the method practical.</p>
</section>
<section id="stochastic-gradient-hamiltonian-monte-carlo" class="level2" data-number="4">
<h2 data-number="4" class="anchored" data-anchor-id="stochastic-gradient-hamiltonian-monte-carlo"><span class="header-section-number">4</span> Stochastic gradient Hamiltonian Monte Carlo</h2>
<p>To reduce the cost of calculating <img src="https://latex.codecogs.com/png.latex?%5Cnabla_%7B%5Ctheta%7D%20U(%5Ctheta)"> on the entire dataset <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D">, stochastic versions of Hamiltonian Monte Carlo are proposed in <span class="citation" data-cites="welling2011bayesian chen2014stochastic">(Welling and Teh 2011; Chen et al. 2014)</span>. In this case, the <em>whole-batch</em> gradient, <img src="https://latex.codecogs.com/png.latex?%5Cnabla_%7B%5Ctheta%7D%20U(%5Ctheta)">, is estimated by a noisy estimator, <img src="https://latex.codecogs.com/png.latex?%5Cnabla_%7B%5Ctheta%7D%20%5Ctilde%7BU%7D(%5Ctheta)">, which is based on a single mini-batch, <img src="https://latex.codecogs.com/png.latex?%5Ctilde%7B%5Cmathcal%7BD%7D%7D">, of data. Such a noisy estimator can be written as follows: <span id="eq-noisey-potential-energy"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cnabla_%7B%5Ctheta%7D%20%5Ctilde%7BU%7D(%5Ctheta)%20=%20-%20%5Cfrac%7B%7C%5Cmathcal%7BD%7D%7C%7D%7B%7C%5Ctilde%7B%5Cmathcal%7BD%7D%7D%7C%7D%20%5Csum_%7Bx%20%5Cin%20%5Ctilde%7B%5Cmathcal%7BD%7D%7D%7D%20%5Cnabla_%7B%5Ctheta%7D%20%5Cln%20p(x%20%7C%20%5Ctheta)%20-%20%5Cnabla_%7B%5Ctheta%7D%20%5Cln%20p(%5Ctheta).%0A%5Ctag%7B7%7D"></span></p>
<p>If the mini-batch contains sufficiently many independently sampled data points, we can apply the <em>Central Limit Theorem</em> to approximate the noisy gradient of the potential energy as follows: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cnabla_%7B%5Ctheta%7D%20%5Ctilde%7BU%7D(%5Ctheta)%20%5Capprox%20%5Cnabla_%7B%5Ctheta%7D%20U(%5Ctheta)%20+%20%5Csqrt%7BV(%5Ctheta)%7D%20%5Cepsilon,%20%5Cquad%20%5Cepsilon%20%5Csim%20%5Cmathcal%7BN%7D(0,%20I),%0A"> where <img src="https://latex.codecogs.com/png.latex?V(%5Ctheta)"> is the covariance matrix of the stochastic gradient noise <span class="citation" data-cites="welling2011bayesian">(Welling and Teh 2011, Eq. (6))</span>: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20V(%5Ctheta)%20=%20%5Cmathbb%7BE%7D_%7B%5Ctext%7Bmini-batch%20of%20%7D%20x%20%5Cin%20%5Ctilde%7B%5Cmathcal%7BD%7D%7D%7D%20%5Cleft%5B%20%5Cnabla_%7B%5Ctheta%7D%20%5Ctilde%7BU%7D(%5Ctheta)%20%5C,%20%5Cnabla_%7B%5Ctheta%7D%5E%7B%5Ctop%7D%20%5Ctilde%7BU%7D(%5Ctheta)%20%5Cright%5D%20-%20%5Cnabla_%7B%5Ctheta%7D%20U(%5Ctheta)%20%5C,%20%5Cnabla_%7B%5Ctheta%7D%5E%7B%5Ctop%7D%20U(%5Ctheta),%0A"> and <img src="https://latex.codecogs.com/png.latex?%5Csqrt%7BV(%5Ctheta)%7D"> denotes a matrix such that <img src="https://latex.codecogs.com/png.latex?%5Csqrt%7BV(%5Ctheta)%7D%20%5Cleft(%20%5Csqrt%7BV(%5Ctheta)%7D%20%5Cright)%5E%7B%5Ctop%7D%20=%20V(%5Ctheta)"> (e.g., the factor obtained from a Cholesky decomposition).</p>
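<p>The behaviour of the mini-batch gradient can be checked numerically on a toy model. The sketch below is an illustration only and assumes a one-dimensional model that is not part of the original discussion: observations distributed as a unit-variance Gaussian centred on the parameter, with a standard normal prior, so that the full gradient of the potential energy is the sum of the differences between the parameter and each observation, plus the parameter itself. The function names and the batch size are hypothetical.</p>

```python
import random


def full_gradient(theta, data):
    """Gradient of U for the assumed model x ~ N(theta, 1), theta ~ N(0, 1)."""
    return sum(theta - x for x in data) + theta


def minibatch_gradient(theta, data, batch_size, rng):
    """Noisy estimate: rescale the mini-batch likelihood term by |D| / |batch|."""
    batch = rng.sample(data, batch_size)
    scale = len(data) / batch_size
    return scale * sum(theta - x for x in batch) + theta


rng = random.Random(0)
data = [rng.gauss(1.0, 1.0) for _ in range(1000)]
theta = 0.5

exact = full_gradient(theta, data)
noisy = [minibatch_gradient(theta, data, 50, rng) for _ in range(2000)]

# The empirical mean approximates the full gradient (the estimator is unbiased),
# and the empirical variance plays the role of V(theta) in one dimension.
mean_noisy = sum(noisy) / len(noisy)
var_noisy = sum((g - mean_noisy) ** 2 for g in noisy) / (len(noisy) - 1)
print(exact, mean_noisy, var_noisy)
```

<p>Averaged over many mini-batches the noisy gradient matches the full gradient, while any single mini-batch gradient deviates from it by an amount whose variance grows as the mini-batch shrinks.</p>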
<section id="naive-stochastic-gradient-hamiltonian-monte-carlo" class="level3" data-number="4.1">
<h3 data-number="4.1" class="anchored" data-anchor-id="naive-stochastic-gradient-hamiltonian-monte-carlo"><span class="header-section-number">4.1</span> Naive stochastic gradient Hamiltonian Monte Carlo</h3>
<p>A naive way is to directly substitute the noisy estimator in Equation&nbsp;7 into the Hamiltonian dynamics in Equation&nbsp;6: <span id="eq-noisy-hamiltonian-dynamics"><img src="https://latex.codecogs.com/png.latex?%0A%5Cboxed%7B%0A%20%20%20%20%5Cbegin%7Bdcases%7D%0A%20%20%20%20%20%20%20%20%5Cfrac%7B%5Coperatorname%7Bd%7D%20%5Ctheta%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20&amp;%20=%20%20M%5E%7B-1%7D%20%5Crho%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5C%5C%0A%20%20%20%20%20%20%20%20%5Cfrac%7B%5Coperatorname%7Bd%7D%20%5Crho%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20&amp;%20=%20-%5Cnabla_%7B%5Ctheta%7D%20%5Ctilde%7BU%7D(%5Ctheta)%20=%20-%20%5Cnabla_%7B%5Ctheta%7D%20U(%5Ctheta)%20+%20%5Csqrt%7BV(%5Ctheta)%7D%20%5Cepsilon,%20%5Cquad%20%5Cepsilon%20%5Csim%20%5Cmathcal%7BN%7D(0,%20I).%0A%20%20%20%20%5Cend%7Bdcases%7D%0A%7D%0A%5Ctag%7B8%7D"></span></p>
<p>In this case, the Hamiltonian is not guaranteed to be invariant: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%20%20%20%20%5Cfrac%7B%5Coperatorname%7Bd%7D%20H%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20&amp;%20=%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Cfrac%7B%5Coperatorname%7Bd%7D%5Ctheta_%7Bi%7D%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20%5Cfrac%7B%5Cpartial%20H%7D%7B%5Cpartial%20%5Ctheta_%7Bi%7D%7D%20+%20%5Cfrac%7B%5Coperatorname%7Bd%7D%20%5Crho_%7Bi%7D%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20%5Cfrac%7B%5Cpartial%20H%7D%7B%5Cpartial%20%5Crho_%7Bi%7D%7D%20%5C%5C%0A%20%20%20%20&amp;%20=%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20(M%5E%7B-1%7D%20%5Crho)_%7Bi%7D%20%5C,%20%5Cfrac%7B%5Cpartial%20U(%5Ctheta)%7D%7B%5Cpartial%20%5Ctheta_%7Bi%7D%7D%20+%20%5Cleft(%20-%20%5Cfrac%7B%5Cpartial%20U(%5Ctheta)%7D%7B%5Cpartial%20%5Ctheta_%7Bi%7D%7D%20+%20%5Cleft(%20%5Csqrt%7BV(%5Ctheta)%7D%20%5Cepsilon%20%5Cright)_%7Bi%7D%20%5Cright)%20%5C,%20(M%5E%7B-1%7D%20%5Crho)_%7Bi%7D%20%5C%5C%0A%20%20%20%20&amp;%20=%20%5Cleft%5B%20%5Csqrt%7BV(%5Ctheta)%7D%20%5Cepsilon%20%5Cright%5D%5E%7B%5Ctop%7D%20M%5E%7B-1%7D%20%5Crho.%0A%5Cend%7Baligned%7D%0A"></p>
<p>As the mini-batch size grows, <img src="https://latex.codecogs.com/png.latex?%5Ctilde%7B%5Cmathcal%7BD%7D%7D%20%5Cto%20%5Cmathcal%7BD%7D">, the variance <img src="https://latex.codecogs.com/png.latex?V(%5Ctheta)"> shrinks, <img src="https://latex.codecogs.com/png.latex?V(%5Ctheta)%20%5Cto%200">, resulting in <img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B%5Coperatorname%7Bd%7D%20H%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20%5Cto%200."> In the limit, the total energy <img src="https://latex.codecogs.com/png.latex?H(%5Ctheta,%20%5Crho)"> is preserved, which recovers the <em>full-batch</em> Hamiltonian Monte Carlo described above.</p>
<p>When using a much smaller mini-batch size: <img src="https://latex.codecogs.com/png.latex?%7C%5Ctilde%7B%5Cmathcal%7BD%7D%7D%7C%20%5Cll%20%7C%5Cmathcal%7BD%7D%7C">, the noise induced by the mini-batch, <img src="https://latex.codecogs.com/png.latex?V(%5Ctheta)">, is large (e.g., in terms of matrix norm), resulting in <img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B%5Coperatorname%7Bd%7D%20H%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20%5Cneq%200."> Consequently, the Hamiltonian is no longer invariant.</p>
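<p>The loss of energy conservation can be observed numerically. The following sketch is a crude illustration under stated assumptions, not a faithful discretisation of the stochastic dynamics: it runs a leapfrog integrator on a toy quadratic potential with unit mass and simply adds Gaussian noise of a chosen standard deviation to every gradient evaluation. The function name and all numerical settings are hypothetical.</p>

```python
import random


def energy_change(noise_std, step_size=0.05, num_steps=200, seed=0):
    """Return H(end) - H(start) for U(theta) = theta^2 / 2 with a noisy gradient."""
    rng = random.Random(seed)
    theta, rho = 1.0, 0.0

    def noisy_grad(th):
        # Exact gradient of U plus assumed Gaussian mini-batch noise.
        return th + noise_std * rng.gauss(0.0, 1.0)

    h_start = 0.5 * theta ** 2 + 0.5 * rho ** 2
    for _ in range(num_steps):
        rho -= 0.5 * step_size * noisy_grad(theta)
        theta += step_size * rho
        rho -= 0.5 * step_size * noisy_grad(theta)
    h_end = 0.5 * theta ** 2 + 0.5 * rho ** 2
    return h_end - h_start


# Without noise the energy is conserved up to a small discretisation error.
exact_drift = energy_change(noise_std=0.0)
# With noise the energy drifts; average the absolute change over several seeds.
noisy_drifts = [energy_change(noise_std=5.0, seed=s) for s in range(50)]
mean_abs_noisy = sum(abs(d) for d in noisy_drifts) / len(noisy_drifts)
print(exact_drift, mean_abs_noisy)
```

<p>In this toy setting the noise-free run keeps the energy almost constant, whereas the noisy runs show a clear change in energy, in line with the non-zero time derivative of the Hamiltonian derived above.</p>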
<p>To correct the error introduced by mini-batches, one needs to perform a Metropolis - Hastings step to either reject or accept the new state. Whether running a short or a long simulation (corresponding to a small or large <img src="https://latex.codecogs.com/png.latex?%5Ctau"> in Algorithm 2), the cost of a Metropolis - Hastings step is still extremely large, and it is wasted if the sample is rejected. One workaround is to run the Metropolis - Hastings step on a subset of data instead of the entire dataset <span class="citation" data-cites="korattikara2014austerity bardenet2014towards">(Korattikara et al. 2014; Bardenet et al. 2014)</span>. There are, of course, some trade-offs in using such approaches.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Hockey puck on ice surface with random wind
</div>
</div>
<div class="callout-body-container callout-body">
<p>To continue with the same analogy of a hockey puck, the environment is now different with random wind blowing over the ice surface. That random wind may push the hockey puck further away in some random direction.</p>
</div>
</div>
<p>Whether the joint distribution <img src="https://latex.codecogs.com/png.latex?p_%7BH%7D(%5Ctheta,%20%5Crho)"> is stationary can be determined by analysing the corresponding Fokker - Planck equation, as shown in the Appendix on the stationarity of the stochastic gradient due to mini-batches. In this case, <img src="https://latex.codecogs.com/png.latex?p_%7BH%7D(%5Ctheta,%20%5Crho)"> is shown to be non-stationary.</p>
<p>In <span class="citation" data-cites="chaudhari2018stochastic">(Chaudhari and Soatto 2018)</span>, the joint distribution <img src="https://latex.codecogs.com/png.latex?p_%7BH%7D(%5Ctheta,%20%5Crho)%20%5Cpropto%20%5Cexp(-H(%5Ctheta,%20%5Crho))"> in Equation&nbsp;5 is assumed to be stationary under the stochastic dynamics in Equation&nbsp;8. This is equivalent to proving that the left hand side term in the Fokker - Planck equation is zero: <img src="https://latex.codecogs.com/png.latex?%5Cfrac%7B%5Cpartial%20p_%7BH%7D(%5Ctheta,%20%5Crho)%7D%7B%5Cpartial%20t%7D%20=%200">. The authors then analyse and show that <em>the stationary distribution does not converge to the desired posterior distribution in general</em> <span class="citation" data-cites="chaudhari2018stochastic">(Chaudhari and Soatto 2018)</span>. This is, however, only true if the stationary distribution exists. And in this case, we prove that it does not (the distribution is non-stationary as shown in Section stationary of stochastic gradient due to mini-batches).</p>
</section>
<section id="stochastic-gradient-hamiltonian-monte-carlo-with-friction" class="level3" data-number="4.2">
<h3 data-number="4.2" class="anchored" data-anchor-id="stochastic-gradient-hamiltonian-monte-carlo-with-friction"><span class="header-section-number">4.2</span> Stochastic gradient Hamiltonian Monte Carlo with “friction”</h3>
<p>One way to counteract the noise in the stochastic estimate of the gradient of the potential energy, <img src="https://latex.codecogs.com/png.latex?%5Cnabla_%7B%5Ctheta%7D%20%5Ctilde%7BU%7D(%5Ctheta)">, is to introduce a “friction” term to the momentum update: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Bdcases%7D%0A%20%20%20%20%5Cfrac%7B%5Coperatorname%7Bd%7D%20%5Ctheta%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20&amp;%20=%20%20M%5E%7B-1%7D%20%5Crho%20%5C%5C%0A%20%20%20%20&amp;%20%5C%5C%0A%20%20%20%20%5Cfrac%7B%5Coperatorname%7Bd%7D%20%5Crho%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20&amp;%20=%20-%20%5Cnabla_%7B%5Ctheta%7D%20U(%5Ctheta)%20%5Ctextcolor%7BCrimson%7D%7B-%20F%20M%5E%7B-1%7D%20%5Crho%7D%20+%20%5Csqrt%7BV(%5Ctheta)%7D%20%5Cepsilon,%20%5Cquad%20%5Cepsilon%20%5Csim%20%5Cmathcal%7BN%7D(0,%20I),%0A%5Cend%7Bdcases%7D%0A"> where <img src="https://latex.codecogs.com/png.latex?F%20%5Cin%20%5Cmathbb%7BR%7D%5E%7BN_%7B%5Crho%7D%20%5Ctimes%20N_%7B%5Crho%7D%7D"> denotes the friction coefficient matrix. One requirement for <img src="https://latex.codecogs.com/png.latex?F"> is that <img src="https://latex.codecogs.com/png.latex?F%20%5Csucceq%20%5Csqrt%7BV%7D"> (see the appendix on stochastic gradient with friction for further details).</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Hockey puck on a friction surface with random wind
</div>
</div>
<div class="callout-body-container callout-body">
<p>To continue with the same analogy, the hockey puck is now sliding not on a frictionless ice surface, but on a street surface whose asphalt induces friction. There is still a random wind blowing. However, the friction of the surface prevents the hockey puck from drifting too far from its expected position.</p>
</div>
</div>
<p>In this case, one can prove that the joint distribution <img src="https://latex.codecogs.com/png.latex?p_%7BH%7D(%5Ctheta,%20%5Crho)"> is stationary.</p>
<p>To link this sampling to the stochastic gradient descent, one can sample <img src="https://latex.codecogs.com/png.latex?%5Crho(t)%20%5Csim%20%5Cmathcal%7BN%7D(0,%20M)"> and apply one leapfrog step as follows: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Bdcases%7D%0A%20%20%20%20%5Crho%5Cleft(%20t%20+%20%5Cfrac%7B1%7D%7B2%7D%20%5Cright)%20&amp;%20=%20%5Crho(t)%20+%20%5Cfrac%7B%5Calpha%7D%7B2%7D%20%5Cleft%5B%20-%20%5Cnabla_%7B%5Ctheta%7D%20U(%5Ctheta)%20%5Ctextcolor%7BCrimson%7D%7B-%20F%20M%5E%7B-1%7D%20%5Crho%7D%20+%20%5Csqrt%7BV(%5Ctheta)%7D%20%5Cepsilon%20%5Cright%5D%20%5C%5C%0A%20%20%20%20%5Ctheta%20(t%20+%201)%20&amp;%20=%20%5Ctheta(t)%20+%20%5Calpha%20M%5E%7B-1%7D%20%5Crho%5Cleft(%20t%20+%20%5Cfrac%7B1%7D%7B2%7D%20%5Cright).%0A%5Cend%7Bdcases%7D%0A"></p>
<p>It can be simplified by substituting <img src="https://latex.codecogs.com/png.latex?%5Crho%5Cleft(%20t%20+%20%5Cfrac%7B1%7D%7B2%7D%20%5Cright)"> into the expression of <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> to obtain: <img src="https://latex.codecogs.com/png.latex?%0A%5Cboxed%7B%0A%20%20%20%20%5Ctheta(t%20+%201)%20=%20%5Ctheta(t)%20+%20%5Cfrac%7B%5Calpha%5E%7B2%7D%7D%7B2%7D%20M%5E%7B-1%7D%20%5Cleft%5B%20-%20%5Cnabla_%7B%5Ctheta%7D%20U(%5Ctheta)%20%5Ctextcolor%7BCrimson%7D%7B-%20F%20M%5E%7B-1%7D%20%5Crho%7D%20+%20%5Csqrt%7BV(%5Ctheta)%7D%20%5Cepsilon%20%5Cright%5D%20+%20%5Calpha%20M%5E%7B-1%7D%20%5Crho(t),%0A%7D%0A"> which has a form similar to that of <em>Stochastic Gradient Langevin Dynamics</em> <span class="citation" data-cites="welling2011bayesian">(Welling and Teh 2011)</span>.</p>
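<p>As a rough illustration, the two-line leapfrog update above can be sketched for a scalar parameter. This is only a minimal sketch under stated assumptions: the function name <code>sghmc_step</code>, its argument names, and the use of scalars for the mass, friction and noise variance are all hypothetical choices made for this example, not part of any referenced implementation.</p>

```python
import math
import random


def sghmc_step(theta, grad_u, alpha, mass, friction, noise_var, rng):
    """One update of the friction-augmented rule for a scalar theta (illustrative)."""
    rho = rng.gauss(0.0, math.sqrt(mass))   # resample momentum: rho(t) ~ N(0, M)
    eps = rng.gauss(0.0, 1.0)               # gradient noise: eps ~ N(0, 1)
    # Force = -grad U(theta) - F M^{-1} rho + sqrt(V) eps
    force = -grad_u(theta) - friction * rho / mass + math.sqrt(noise_var) * eps
    rho_half = rho + 0.5 * alpha * force    # half-step momentum update
    return theta + alpha * rho_half / mass  # full-step position update


# Usage on the assumed potential U(theta) = theta^2 / 2, whose gradient is theta.
rng = random.Random(0)
theta = 2.0
for _ in range(100):
    theta = sghmc_step(theta, lambda t: t, alpha=0.1, mass=1.0,
                       friction=1.0, noise_var=1.0, rng=rng)
```

<p>Expanding the return value reproduces the boxed expression: the term in <code>alpha * rho_half / mass</code> splits into the momentum contribution and the half-squared step size times the force.</p>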
</section>
</section>
<section id="conclusion" class="level2" data-number="5">
<h2 data-number="5" class="anchored" data-anchor-id="conclusion"><span class="header-section-number">5</span> Conclusion</h2>
<p>This post reviews some seminal studies in <em>stochastic gradient</em> and <em>Monte Carlo sampling</em>. Many subsequent studies have explored and extended these ideas, mostly building on top of them to achieve better performance. However, it is important to understand the basics before moving on to more advanced methods. Hopefully, this post will be useful in one way or another.</p>
</section>





<div id="quarto-appendix" class="default"><section id="appendices" class="level2 appendix unnumbered"><h2 class="anchored quarto-appendix-heading">Appendices</h2><div class="quarto-appendix-contents">

</div></section><section id="fokker---planck-equation" class="level2 appendix" data-number="6"><h2 class="anchored quarto-appendix-heading"><span class="header-section-number">6</span> Fokker - Planck equation</h2><div class="quarto-appendix-contents">

<p>The Fokker - Planck equation is used to analyse the evolution of the distribution of the variables in a stochastic differential equation: <span id="eq-sde"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Coperatorname%7Bd%7D%20x(t)%20=%20-%20%5Cnabla%20f(x)%20%5Coperatorname%7Bd%7D%20t%20+%20%5Csqrt%7B2%20%5Ctau%20V(x)%7D%20%5Coperatorname%7Bd%7D%20W(t),%0A%5Ctag%7B9%7D"></span> where <img src="https://latex.codecogs.com/png.latex?f(x)"> is some function (e.g., a loss function), <img src="https://latex.codecogs.com/png.latex?V(x)"> is a diffusion matrix, <img src="https://latex.codecogs.com/png.latex?W(t)"> is a Brownian motion, and <img src="https://latex.codecogs.com/png.latex?%5Ctau"> is a temperature.</p>
<div id="lem-time-variant-distribution" class="theorem lemma">
<p><span class="theorem-title"><strong>Lemma 1</strong></span> The distribution <img src="https://latex.codecogs.com/png.latex?p(x)%20%5Cpropto%20%5Cexp%5Cleft(%20-H(x)%20%5Cright)"> of the variable <img src="https://latex.codecogs.com/png.latex?x"> in Equation&nbsp;9 evolves following the Fokker - Planck equation: <span id="eq-fokker-planck-equation"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cfrac%7B%5Cpartial%20p(x)%7D%7B%5Cpartial%20t%7D%20=%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20%5Cnabla%20f(x)%20p(x)%20+%20%5Ctau%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20V(x)%20p(x)%20%5Cright%5D%20%5Cright%5D,%0A%5Ctag%7B10%7D"></span> where: <img src="https://latex.codecogs.com/png.latex?%5Cnabla%20%5Ccdot"> denotes the divergence, and the divergence operator is applied column-wise to matrices.</p>
</div>
<p>Thus, one can prove that the distribution of the solution of the stochastic differential equation in Equation&nbsp;9 is invariant (stationary) by showing that <img src="https://latex.codecogs.com/png.latex?%5Cpartial%20p(x)/%5Cpartial%20t%20=%200">.</p>
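<p>A short worked example may help to see how this check goes. The following assumes, purely for illustration, an identity diffusion matrix and a Gibbs-type density; neither assumption is taken from the text above, and the general case with a state-dependent diffusion matrix needs more care.</p>

```latex
% Assumptions for this example only: V(x) = I and p(x) \propto \exp(-f(x) / \tau).
\begin{aligned}
    \nabla \cdot \left[ V(x) p(x) \right] & = \nabla p(x) = -\frac{1}{\tau} \nabla f(x) \, p(x), \\
    \nabla f(x) \, p(x) + \tau \nabla \cdot \left[ V(x) p(x) \right] & = \nabla f(x) \, p(x) - \nabla f(x) \, p(x) = 0, \\
    \frac{\partial p(x)}{\partial t} & = \nabla \cdot 0 = 0.
\end{aligned}
```

<p>Under these assumptions, the bracketed term of Equation&nbsp;10 vanishes identically, so the density does not change over time.</p>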
</div></section><section id="stationary-distribution-of-parameters-obtained-from-sgd" class="level2 appendix" data-number="7"><h2 class="anchored quarto-appendix-heading"><span class="header-section-number">7</span> Stationary distribution of parameters obtained from SGD</h2><div class="quarto-appendix-contents">

<p>The main focus of this section is to investigate the stationary distribution <img src="https://latex.codecogs.com/png.latex?p(%5Ctheta,%20%5Crho)"> obtained through the stochastic gradient Hamiltonian Monte Carlo. Two types of noise are considered: <em>(i)</em> noise due to the mini-batch effect and <em>(ii)</em> injected noise as in <span class="citation" data-cites="welling2011bayesian">(Welling and Teh 2011)</span>. The main tool is the Fokker - Planck equation presented in the previous appendix. To use the Fokker - Planck equation, the two variables of interest are stacked into a single vector: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20z%20=%20%5Cbegin%7Bbmatrix%7D%0A%20%20%20%20%20%20%20%20%5Ctheta%20&amp;%20%5Crho%0A%20%20%20%20%5Cend%7Bbmatrix%7D%5E%7B%5Ctop%7D.%0A"></p>


</div></section><section id="stochastic-gradient-with-mini-batches" class="level3 appendix" data-number="7.1"><h2 class="anchored quarto-appendix-heading"><span class="header-section-number">7.1</span> Stochastic gradient with mini-batches</h2><div class="quarto-appendix-contents">

<p>The dynamics in Equation&nbsp;8 can be rewritten as: <span id="eq-stochastic-hamiltonian-mc-naive"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cfrac%7B%5Coperatorname%7Bd%7D%20z%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20=%20%5Cfrac%7B%5Coperatorname%7Bd%7D%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20%5Cbegin%7Bbmatrix%7D%0A%20%20%20%20%20%20%20%20%5Ctheta%20%5C%5C%0A%20%20%20%20%20%20%20%20%5Crho%0A%20%20%20%20%5Cend%7Bbmatrix%7D%20=%20-%20%5Cunderbrace%7B%5Cbegin%7Bbmatrix%7D%0A%20%20%20%20%20%20%20%200%20&amp;%20-I%20%5C%5C%0A%20%20%20%20%20%20%20%20I%20&amp;%200%0A%20%20%20%20%5Cend%7Bbmatrix%7D%7D_%7BG%7D%20%5Cunderbrace%7B%5Cbegin%7Bbmatrix%7D%0A%20%20%20%20%20%20%20%20%5Cnabla_%7B%5Ctheta%7D%20U(%5Ctheta)%20%5C%5C%0A%20%20%20%20%20%20%20%20M%5E%7B-1%7D%20%5Crho%0A%20%20%20%20%5Cend%7Bbmatrix%7D%7D_%7B%5Cnabla%20H(z)%7D%20+%20%5Cunderbrace%7B%5Cbegin%7Bbmatrix%7D%0A%20%20%20%20%20%20%20%200%20&amp;%200%20%5C%5C%0A%20%20%20%20%20%20%20%200%20&amp;%20%5Csqrt%7BV(%5Ctheta)%7D%0A%20%20%20%20%5Cend%7Bbmatrix%7D%7D_%7BD(z)%7D%20%5Cunderbrace%7B%5Cbegin%7Bbmatrix%7D%0A%20%20%20%20%20%20%20%200%20%5C%5C%0A%20%20%20%20%20%20%20%20%5Cepsilon%0A%20%20%20%20%5Cend%7Bbmatrix%7D%7D_%7B%5Cepsilon%5E%7B%5Cprime%7D%7D,%0A%5Ctag%7B11%7D"></span> where: <img src="https://latex.codecogs.com/png.latex?%5Cepsilon%20%5Csim%20%5Cmathcal%7BN%7D(0,%20I)">.</p>
<p>The corresponding Fokker - Planck equation can be written as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cfrac%7B%5Cpartial%20p(z)%7D%7B%5Cpartial%20t%7D%20=%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20G%20%5C,%20%5Cnabla%20H(z)%20%5C,%20p(z)%20+%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20D(z)%20p(z)%20%5Cright%5D%20%5Cright%5D.%0A"></p>
<p>Note that: <img src="https://latex.codecogs.com/png.latex?p(z)%20=%20%5Cexp%5Cleft(%20-H(z)%20%5Cright)%20/%20Z"> (assuming the temperature: <img src="https://latex.codecogs.com/png.latex?%5Ctau%20=%201">), then <img src="https://latex.codecogs.com/png.latex?H(z)%20=%20-%20%5Cln%20p(z)%20-%20%5Cln%20Z">. Thus, we can rewrite the Fokker - Planck equation as follows: <span id="eq-fokker-sgd-minibatch"><img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%20%20%20%20%5Cfrac%7B%5Cpartial%20p(z)%7D%7B%5Cpartial%20t%7D%20&amp;%20=%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20G%20%5C,%20%5Cnabla%20%5Cleft%5B%20-%5Cln%20p(z)%20%5Cright%5D%20%5C,%20p(z)%20+%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20D(z)%20p(z)%20%5Cright%5D%20%5Cright%5D%20%5C%5C%0A%20%20%20%20&amp;%20=%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20-%20G%20%5C,%20%5Cnabla%20p(z)%20+%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20D(z)%20p(z)%20%5Cright%5D%20%5Cright%5D%20%5C%5C%0A%20%20%20%20&amp;%20=%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20-%20G%20%5C,%20%5Cnabla%20p(z)%20%5Cright%5D%20+%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20D(z)%20p(z)%20%5Cright%5D%20%5Cright%5D%20%5C%5C%0A%20%20%20%20&amp;%20=%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20D(z)%20p(z)%20%5Cright%5D%20%5Cright%5D.%0A%5Cend%7Baligned%7D%0A%5Ctag%7B12%7D"></span></p>
<p>For the last equality, we use the fact that: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20G%20%5C,%20%5Cnabla%20p(z)%20%5Cright%5D%20=%20-%5Cfrac%7B%5Cpartial%5E%7B2%7D%20p(%5Ctheta,%20%5Crho)%7D%7B%5Cpartial%20%5Ctheta%20%5C,%20%5Cpartial%20%5Crho%7D%20+%20%5Cfrac%7B%5Cpartial%5E%7B2%7D%20p(%5Ctheta,%20%5Crho)%7D%7B%5Cpartial%20%5Ctheta%20%5C,%20%5Cpartial%20%5Crho%7D%20=%200.%0A"></p>
<p>The result in Equation&nbsp;12 does not guarantee that <img src="https://latex.codecogs.com/png.latex?%5Cpartial%20p(%5Ctheta,%20%5Crho)%20/%20%5Cpartial%20t%20=%200."> In other words, there is not enough evidence to prove that <img src="https://latex.codecogs.com/png.latex?p(%5Ctheta,%20%5Crho)"> is stationary.</p>
<p>In practice, when we perform SGD, the covariance matrix <img src="https://latex.codecogs.com/png.latex?V(%5Ctheta)"> often becomes smaller and smaller. In such a case, we can assume that <img src="https://latex.codecogs.com/png.latex?V(%5Ctheta)%20%5Capprox%200">, and hence, the distribution <img src="https://latex.codecogs.com/png.latex?p(%5Ctheta,%20%5Crho)"> is approximately stationary.</p>
</div></section><section id="stochastic-gradient-with-friction" class="level3 appendix" data-number="7.2"><h2 class="anchored quarto-appendix-heading"><span class="header-section-number">7.2</span> Stochastic gradient with friction</h2><div class="quarto-appendix-contents">



</div></section><section id="known-covariance-matrix" class="level4 appendix" data-number="7.2.1"><h2 class="anchored quarto-appendix-heading"><span class="header-section-number">7.2.1</span> Known covariance matrix </h2><div class="quarto-appendix-contents">

<p>According to <span class="citation" data-cites="chen2014stochastic">(Chen et al. 2014)</span>, if the covariance matrix <img src="https://latex.codecogs.com/png.latex?V(%5Ctheta)"> induced by the mini-batch effect is known, then one can introduce a friction force to the system as follows: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Bdcases%7D%0A%20%20%20%20%5Cfrac%7B%5Coperatorname%7Bd%7D%20%5Ctheta%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20&amp;%20=%20%20M%5E%7B-1%7D%20%5Crho%20%5C%5C%0A%20%20%20%20%5Cfrac%7B%5Coperatorname%7Bd%7D%20%5Crho%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20&amp;%20=%20-%20%5Cnabla_%7B%5Ctheta%7D%20U(%5Ctheta)%20%5Ctextcolor%7BCrimson%7D%7B-%20%5Csqrt%7BV(%5Ctheta)%7D%20M%5E%7B-1%7D%20%5Crho%7D%20+%20%5Csqrt%7BV(%5Ctheta)%7D%20%5Cepsilon,%20%5Cquad%20%5Cepsilon%20%5Csim%20%5Cmathcal%7BN%7D(0,%20I).%0A%5Cend%7Bdcases%7D%0A"></p>
<p>This can be rewritten in the form of vectors and matrices as follows: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cfrac%7B%5Coperatorname%7Bd%7D%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20%5Cbegin%7Bbmatrix%7D%0A%20%20%20%20%20%20%20%20%5Ctheta%20%5C%5C%0A%20%20%20%20%20%20%20%20%5Crho%0A%20%20%20%20%5Cend%7Bbmatrix%7D%20=%20-%20%5Cbegin%7Bbmatrix%7D%0A%20%20%20%20%20%20%20%200%20&amp;%20-I%20%5C%5C%0A%20%20%20%20%20%20%20%20I%20&amp;%20%5Csqrt%7BV(%5Ctheta)%7D%0A%20%20%20%20%5Cend%7Bbmatrix%7D%20%5Cbegin%7Bbmatrix%7D%0A%20%20%20%20%20%20%20%20%5Cnabla_%7B%5Ctheta%7D%20U(%5Ctheta)%20%5C%5C%0A%20%20%20%20%20%20%20%20M%5E%7B-1%7D%20%5Crho%0A%20%20%20%20%5Cend%7Bbmatrix%7D%20+%20%5Cbegin%7Bbmatrix%7D%0A%20%20%20%20%20%20%20%200%20&amp;%200%20%5C%5C%0A%20%20%20%20%20%20%20%200%20&amp;%20%5Csqrt%7BV(%5Ctheta)%7D%0A%20%20%20%20%5Cend%7Bbmatrix%7D%20%5Cbegin%7Bbmatrix%7D%0A%20%20%20%20%20%20%20%200%20%5C%5C%0A%20%20%20%20%20%20%20%20%5Cepsilon%0A%20%20%20%20%5Cend%7Bbmatrix%7D.%0A"></p>
<p>Following the notations defined in Equation&nbsp;11, the system dynamics can be rewritten as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cfrac%7B%5Coperatorname%7Bd%7D%20z%7D%7B%5Coperatorname%7Bd%7D%20t%7D%20=%20-%20%5Cleft%5B%20G%20+%20D(z)%20%5Cright%5D%20%5Cnabla%20H(z)%20+%20D(z)%20%5Cepsilon%5E%7B%5Cprime%7D.%0A"></p>
<p>The corresponding Fokker - Planck equation is then written as: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%20%20%20%20%5Cfrac%7B%5Cpartial%20p(z,%20t)%7D%7B%5Cpartial%20t%7D%20&amp;%20=%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20%5Cleft%5B%20G%20+%20D(z)%20%5Cright%5D%20%5Cnabla%20H(z)%20%5C,%20p(z)%20+%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20D(z)%20%5C,%20p(z)%20%5Cright%5D%20%5Cright%5D%20%5C%5C%0A%20%20%20%20&amp;%20=%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20-%20D(z)%20%5Cnabla%20p(z)%20+%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20D(z)%20%5C,%20p(z)%20%5Cright%5D%20%5Cright%5D%20%5C%5C%0A%20%20%20%20&amp;%20=%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20-%20D(z)%20%5Cnabla%20p(z)%20+%20D(z)%20%5Cnabla%20p(z)%20+%20p(z)%20%5Cnabla%20%5Ccdot%20D(z)%20%5Cright%5D%20%5C%5C%0A%20%20%20%20&amp;%20=%20%5Cnabla%20%5Ccdot%20%5Cleft%5B%20p(z)%20%5Cnabla%20%5Ccdot%20D(z)%20%5Cright%5D%20%5C%5C%0A%20%20%20%20&amp;%20=%200.%0A%5Cend%7Baligned%7D%0A"></p>
<p>The third equality is due to the Identity 1.11.16 in <a href="https://pkel015.connect.amazon.auckland.ac.nz/SolidMechanicsBooks/Part_III/Chapter_1_Vectors_Tensors/Vectors_Tensors_14_Tensor_Calculus.pdf">Tensor calculus note</a>.</p>
<p>The last equality holds due to the fact that <img src="https://latex.codecogs.com/png.latex?%5Cnabla%20%5Ccdot%20D(z)%20=%200">. This can easily be proved by using the definition of <em>divergence</em> <img src="https://latex.codecogs.com/png.latex?%5Cnabla%20%5Ccdot"> and the structure of <img src="https://latex.codecogs.com/png.latex?D(z)"> (noise is added to <img src="https://latex.codecogs.com/png.latex?%5Crho"> although it depends on <img src="https://latex.codecogs.com/png.latex?%5Ctheta">).</p>
<p>In summary, adding the friction force <img src="https://latex.codecogs.com/png.latex?%5Ctextcolor%7BBrickRed%7D%7B-%20%5Csqrt%7BV(%5Ctheta)%7D%20M%5E%7B-1%7D%20%5Crho%7D"> to balance the gradient noise results in a stationary distribution <img src="https://latex.codecogs.com/png.latex?p_%7BH%7D(%5Ctheta,%20%5Crho)">.</p>
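<p>The contrast between the dynamics with and without friction can also be checked numerically. The sketch below is a toy simulation under several assumptions that are not in the derivation above: a one-dimensional quadratic potential with unit mass, a constant scalar coefficient standing in for the square root of the covariance, and a noise increment whose variance is twice that coefficient times the step size (the scaling of Equation&nbsp;9 with unit temperature). The function name <code>mean_energy</code> and its arguments are hypothetical.</p>

```python
import math
import random


def mean_energy(friction, noise, n_chains=100, n_steps=1000, dt=0.02, seed=0):
    """Average H = theta^2/2 + rho^2/2 at the final time over independent chains."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_chains):
        theta = rng.gauss(0.0, 1.0)  # start from exp(-H): theta, rho ~ N(0, 1)
        rho = rng.gauss(0.0, 1.0)
        for _ in range(n_steps):
            kick = math.sqrt(2.0 * noise * dt) * rng.gauss(0.0, 1.0)
            rho += (-theta - friction * rho) * dt + kick  # grad U(theta) = theta
            theta += rho * dt                             # uses the updated momentum
        total += 0.5 * theta * theta + 0.5 * rho * rho
    return total / n_chains


energy_naive = mean_energy(friction=0.0, noise=1.0)     # noise only, no friction
energy_friction = mean_energy(friction=1.0, noise=1.0)  # noise balanced by friction
print(energy_naive, energy_friction)
```

<p>In this toy setting the average energy without friction should come out much larger than with friction, because the noise keeps pumping energy into the system, while the friction term keeps the chains close to the initial distribution.</p>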
</div></section><section id="practical-stochastic-gradient-hamiltonian-monte-carlo-with-unknown-covariance-matrix" class="level4 appendix" data-number="7.2.2"><h2 class="anchored quarto-appendix-heading"><span class="header-section-number">7.2.2</span> Practical stochastic gradient Hamiltonian Monte Carlo with unknown covariance matrix</h2><div class="quarto-appendix-contents">

<p>In practice, we might not know the covariance matrix <img src="https://latex.codecogs.com/png.latex?V(%5Ctheta)">. In such a situation, one might introduce a friction matrix <img src="https://latex.codecogs.com/png.latex?F"> that satisfies: <img src="https://latex.codecogs.com/png.latex?F%20%5Csucceq%20%5Csqrt%7BV(%5Ctheta)%7D">. In other words, the matrix <img src="https://latex.codecogs.com/png.latex?F%20-%20%5Csqrt%7BV(%5Ctheta)%7D"> is positive semi-definite. In this case, the system is over-damped and the total energy <img src="https://latex.codecogs.com/png.latex?H(%5Ctheta,%20%5Crho)"> will gradually decrease to 0.</p>
<div class="callout callout-style-default callout-note callout-titled">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
Note
</div>
</div>
<div class="callout-body-container callout-body">
<p>Even when one can prove that stochastic gradient Hamiltonian Monte Carlo results in a stationary distribution <img src="https://latex.codecogs.com/png.latex?p_%7BH%7D(%5Ctheta,%20%5Crho)">, this does not mean that <img src="https://latex.codecogs.com/png.latex?p_%7BH%7D(%5Ctheta,%20%5Crho)"> is the true posterior of interest (the one without any noise).</p>
</div>
</div>
</div></section><section id="references" class="level2 appendix unnumbered"><h2 class="anchored quarto-appendix-heading">References</h2><div class="quarto-appendix-contents">

<div id="refs" class="references csl-bib-body hanging-indent">
<div id="ref-bardenet2014towards" class="csl-entry">
Bardenet, Rémi, Arnaud Doucet, and Chris Holmes. 2014. <span>“Towards Scaling up <span>M</span>arkov Chain <span>M</span>onte <span>C</span>arlo: An Adaptive Subsampling Approach.”</span> <em>International Conference on Machine Learning</em>, 405–13.
</div>
<div id="ref-chaudhari2018stochastic" class="csl-entry">
Chaudhari, Pratik, and Stefano Soatto. 2018. <span>“Stochastic Gradient Descent Performs Variational Inference, Converges to Limit Cycles for Deep Networks.”</span> <em>International Conference on Learning Representations</em>.
</div>
<div id="ref-chen2014stochastic" class="csl-entry">
Chen, Tianqi, Emily Fox, and Carlos Guestrin. 2014. <span>“Stochastic Gradient <span>H</span>amiltonian <span>M</span>onte <span>C</span>arlo.”</span> <em>International Conference on Machine Learning</em>, 1683–91.
</div>
<div id="ref-korattikara2014austerity" class="csl-entry">
Korattikara, Anoop, Yutian Chen, and Max Welling. 2014. <span>“Austerity in <span>MCMC</span> Land: <span>C</span>utting the <span>M</span>etropolis - <span>H</span>astings Budget.”</span> <em>International Conference on Machine Learning</em>, 181–89.
</div>
<div id="ref-mackay2003information" class="csl-entry">
MacKay, David JC. 2003. <em>Information Theory, Inference and Learning Algorithms</em>. Cambridge University Press.
</div>
<div id="ref-welling2011bayesian" class="csl-entry">
Welling, Max, and Yee W Teh. 2011. <span>“<span>B</span>ayesian Learning via Stochastic Gradient <span>L</span>angevin Dynamics.”</span> <em>International Conference on Machine Learning</em>, 681–88.
</div>
</div>


<!-- -->

</div></section><section class="quarto-appendix-contents" id="quarto-reuse"><h2 class="anchored quarto-appendix-heading">Reuse</h2><div class="quarto-appendix-contents"><div><a rel="license" href="https://creativecommons.org/licenses/by/4.0/">CC BY 4.0</a></div></div></section><section class="quarto-appendix-contents" id="quarto-citation"><h2 class="anchored quarto-appendix-heading">Citation</h2><div><div class="quarto-appendix-secondary-label">BibTeX citation:</div><pre class="sourceCode code-with-copy quarto-appendix-bibtex"><code class="sourceCode bibtex">@online{nguyen2023,
  author = {Nguyen, Cuong},
  title = {Stochastic Gradient and {Hamiltonian} {Monte} {Carlo}},
  date = {2023-11-19},
  url = {https://cnguyen10.github.io/posts/stochastic_grad_hamiltonian_monte_carlo/},
  langid = {en}
}
</code></pre><div class="quarto-appendix-secondary-label">For attribution, please cite this work as:</div><div id="ref-nguyen2023" class="csl-entry quarto-appendix-citeas">
Nguyen, Cuong. 2023. <span>“Stochastic Gradient and Hamiltonian Monte
Carlo.”</span> November 19. <a href="https://cnguyen10.github.io/posts/stochastic_grad_hamiltonian_monte_carlo/">https://cnguyen10.github.io/posts/stochastic_grad_hamiltonian_monte_carlo/</a>.
</div></div></section></div> ]]></description>
  <category>Bayesian Inference</category>
  <category>MCMC</category>
  <category>Statistics</category>
  <guid>https://cnguyen10.github.io/posts/stochastic_grad_hamiltonian_monte_carlo/</guid>
  <pubDate>Sun, 19 Nov 2023 00:00:00 GMT</pubDate>
</item>
<item>
  <title>Expectation - Maximisation algorithm and its applications in finite mixture models</title>
  <dc:creator>Cuong Nguyen</dc:creator>
  <link>https://cnguyen10.github.io/posts/mixture-models/</link>
  <description><![CDATA[ 




<p>Missing data and latent variables are frequently encountered in various machine learning and statistical inference applications. A common example is the finite mixture model, which includes Gaussian mixture and multinomial mixture models. Due to the inherent nature of missing data or latent variables, calculating the likelihood of these models requires marginalisation over the latent variable distribution. This, in turn, complicates the process of maximum likelihood estimation (MLE).</p>
<p>The expectation-maximisation (EM) algorithm, introduced in <span class="citation" data-cites="dempster1977maximum">(Dempster et al. 1977)</span>, offers a general technique for handling latent variable models. The fundamental concept behind the EM algorithm is to iterate between two steps: the E-step (expectation step) and the M-step (maximisation step). In the E-step, the posterior distribution of the latent variables (or missing data) is estimated. This estimated information is then used in the M-step to compute the MLE as if the data were complete. It has been proven that this iterative process never decreases the likelihood. In simpler terms, the EM algorithm converges to a stationary point of the likelihood, typically a local maximum, although a saddle point is possible.</p>
<p>While the EM algorithm is a powerful tool, this explanation may not be as clear as desired. Consequently, this post aims to provide a more accessible explanation of the EM algorithm. Additionally, some readers may question the choice of EM over stochastic gradient descent (SGD), a prevalent optimisation method. This post will, therefore, explore the key differences between these two approaches. Finally, the applications of the EM algorithm in the context of finite mixture modelling, specifically focusing on the MLE problems in Gaussian mixture models and multinomial mixture models, are also demonstrated.</p>
<section id="notations" class="level2" data-number="1">
<h2 data-number="1" class="anchored" data-anchor-id="notations"><span class="header-section-number">1</span> Notations</h2>
<p>Before diving into the explanation and formulation, it is important to define the notations used in this post as follows:</p>
<table class="table-striped table-hover caption-top table">
<caption>Notations used in the formulation of the EM algorithm.</caption>
<thead>
<tr class="header">
<th>Notation</th>
<th>Description</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D%20%5Cin%20%5Cmathbb%7BR%7D%5E%7BD%7D"></td>
<td>observable data</td>
</tr>
<tr class="even">
<td><img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bz%7D%20%5Cin%20%5Cmathbb%7BR%7D%5E%7BK%7D"></td>
<td>latent variable or missing data</td>
</tr>
<tr class="odd">
<td><img src="https://latex.codecogs.com/png.latex?%5Ctheta%20%5Cin%20%5CTheta"></td>
<td>the parameter of interest in MLE</td>
</tr>
</tbody>
</table>
</section>
<section id="em-algorithm" class="level2" data-number="2">
<h2 data-number="2" class="anchored" data-anchor-id="em-algorithm"><span class="header-section-number">2</span> EM algorithm</h2>
<p>The formulation presented in this post follows a probabilistic approach. In probabilistic modelling, there are two processes: data generation (also known as a <em>forward</em> problem) and parameter inference (also known as an <em>inverse problem</em>).</p>
<section id="sec-data-generation" class="level3" data-number="2.1">
<h3 data-number="2.1" class="anchored" data-anchor-id="sec-data-generation"><span class="header-section-number">2.1</span> Data generation</h3>
<p>The data is generated as follows:</p>
<ul>
<li>draw the parameter <img src="https://latex.codecogs.com/png.latex?%5Cpi"> from its prior: <img src="https://latex.codecogs.com/png.latex?%5Cpi%20%5Csim%20%5CPr(%5Cpi)">,</li>
<li>draw the parameter <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> from its prior: <img src="https://latex.codecogs.com/png.latex?%5Ctheta%20%5Csim%20%5CPr(%5Ctheta)">,</li>
<li>draw a <em>hidden</em> sample <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bz%7D"> from a prior distribution: <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bz%7D%20%5Csim%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cpi)">, and</li>
<li>draw an <em>observable</em> sample <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D"> given <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bz%7D"> as follows: <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D%20%5Csim%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cmathbf%7Bz%7D,%20%5Ctheta)">,</li>
</ul>
<p>where <img src="https://latex.codecogs.com/png.latex?%5Cpi"> and <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> are the parameters of the model of interest.</p>
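<p>The generation steps listed above can be sketched as ancestral sampling. The sketch makes assumptions that go beyond the text: the parameters are fixed by hand instead of being drawn from their priors, the latent variable is represented by a component index, and the observation model is a univariate Gaussian per component. The function name <code>generate</code> and its arguments are hypothetical.</p>

```python
import random


def generate(n, pi, means, stds, seed=0):
    """Ancestral sampling: z ~ Categorical(pi), then x ~ N(means[z], stds[z]^2)."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        z = rng.choices(range(len(pi)), weights=pi)[0]  # hidden component index
        x = rng.gauss(means[z], stds[z])                # observable sample given z
        data.append((z, x))
    return data


# Assumed parameter values: pi is the mixing vector, (means, stds) play the role of theta.
samples = generate(5, pi=[0.3, 0.7], means=[-2.0, 3.0], stds=[0.5, 1.0])
```

<p>In the inference problem, only the second element of each pair would be observed; the first element is the hidden sample.</p>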
<div class="callout callout-style-default callout-note callout-titled" title="Parameter $\pi$">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Parameter <img src="https://latex.codecogs.com/png.latex?%5Cpi">
</div>
</div>
<div class="callout-body-container callout-body">
<p>In many tutorials of EM, the parameter <img src="https://latex.codecogs.com/png.latex?%5Cpi"> of the prior of the latent variable <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bz%7D"> is often defined implicitly. In this post, it is defined explicitly to make the explanation easier to follow.</p>
</div>
</div>
<p>Such a data generation process is often visualised by the graphical model shown below:</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">%%{
    init: {
        'theme': 'base',
        'themeVariables': {
            'primaryColor': '#ffffff'
        }
    }
}%%
flowchart LR
    subgraph data["data"]
        z((z)):::nonfilled--&gt;x((x)):::filled;
    end
    pi((π)):::nonfilled--&gt;z;
    theta((θ)):::nonfilled--&gt;x;

    linkStyle default stroke: black;
    classDef nonfilled fill: none;
    style data fill: none;
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
</section>
<section id="sec-parameter-inference" class="level3" data-number="2.2">
<h3 data-number="2.2" class="anchored" data-anchor-id="sec-parameter-inference"><span class="header-section-number">2.2</span> Parameter inference</h3>
<p>Given a set of observed i.i.d. data <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D%20=%20%5C%7B%5Cmathbf%7Bx%7D_%7Bi%7D%5C%7D_%7Bi%20=%201%7D%5E%7BN%7D">, the general objective is to infer the posterior <img src="https://latex.codecogs.com/png.latex?%5CPr(%5Cpi,%20%5Ctheta%20%7C%20%5Cmathbf%7Bx%7D)"> of the parameters <img src="https://latex.codecogs.com/png.latex?%5Cpi"> and <img src="https://latex.codecogs.com/png.latex?%5Ctheta">. Instead of inferring the exact posterior <img src="https://latex.codecogs.com/png.latex?%5CPr(%5Cpi,%20%5Ctheta%20%7C%20%5Cmathbf%7Bx%7D)">, which may be difficult in many cases, one can compute a <em>point estimate</em>, such as MLE or maximum a posteriori (MAP), which can be written as follows:</p>
<p><span id="eq-map"><img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%20%20%20%20%5Cmax_%7B%5Cpi,%20%5Ctheta%7D%20%5Cln%20%5CPr(%5Cpi,%20%5Ctheta%20%7C%20%5C%7B%5Cmathbf%7Bx%7D_%7Bi%7D%5C%7D_%7Bi%20=%201%7D%5E%7BN%7D)%20&amp;%20=%20%5Cmax_%7B%5Cpi,%20%5Ctheta%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Cunderbrace%7B%5Cln%20%5CPr(%5Cmathbf%7Bx%7D_%7Bi%7D%20%7C%20%5Cpi,%20%5Ctheta)%7D_%7B%5Ctext%7Bin-complete%20log-likelihood%7D%7D%20+%20%5Cln%20%5CPr(%5Cpi)%20+%20%5Cln%20%5CPr(%5Ctheta)%20%5C%5C%0A%20%20%20%20&amp;%20=%20%5Cmax_%7B%5Cpi,%20%5Ctheta%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Cln%20%5Cleft%5B%20%5Csum_%7B%5Cmathbf%7Bz%7D_%7Bi%7D%7D%20%5CPr(%5Cmathbf%7Bx%7D_%7Bi%7D,%20%5Cmathbf%7Bz%7D_%7Bi%7D%20%7C%20%5Cpi,%20%5Ctheta)%20%5Cright%5D%20+%20%5Cln%20%5CPr(%5Cpi)%20+%20%5Cln%20%5CPr(%5Ctheta).%0A%5Cend%7Baligned%7D%0A%5Ctag%7B1%7D"></span></p>
<p>Due to the presence of the sum over the latent variable <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bz%7D">, the <em>in-complete</em> log-likelihood may not be evaluated directly on the joint distribution (especially when <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bz%7D"> is continuous), making the optimisation difficult.</p>
<p>Fortunately, according to the data generation presented in Section&nbsp;2.1, the complete log-likelihood <img src="https://latex.codecogs.com/png.latex?%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%20%7C%20%5Cpi,%20%5Ctheta)"> can be evaluated easily:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%7C%20%5Cpi,%20%5Ctheta)%20=%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cmathbf%7Bz%7D,%20%5Ctheta)%20+%20%5Cln%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cpi).%0A"></p>
<p>This factorisation allows EM to get around the difficulty of evaluating the expression in Equation&nbsp;1.</p>
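<p>To make the difference between the two quantities concrete, the sketch below evaluates both for a discrete latent variable. It assumes, for illustration only, a univariate Gaussian observation model and a categorical prior on the latent variable; the function names <code>log_normal</code>, <code>complete_log_lik</code> and <code>incomplete_log_lik</code> are hypothetical.</p>

```python
import math


def log_normal(x, mean, std):
    """Log-density of a univariate Gaussian."""
    return -0.5 * math.log(2.0 * math.pi * std * std) - 0.5 * ((x - mean) / std) ** 2


def complete_log_lik(x, z, pi, means, stds):
    """ln Pr(x, z | pi, theta) = ln Pr(x | z, theta) + ln Pr(z | pi)."""
    return log_normal(x, means[z], stds[z]) + math.log(pi[z])


def incomplete_log_lik(x, pi, means, stds):
    """ln Pr(x | pi, theta): log-sum-exp of the complete log-likelihood over z."""
    terms = [complete_log_lik(x, z, pi, means, stds) for z in range(len(pi))]
    top = max(terms)  # subtract the maximum for numerical stability
    return top + math.log(sum(math.exp(t - top) for t in terms))
```

<p>The complete log-likelihood is a plain sum of two terms, whereas the in-complete one needs a sum over every value of the latent variable inside the logarithm, which is what makes direct optimisation awkward.</p>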
<div class="callout callout-style-default callout-tip callout-titled" title="Main idea behind EM">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Tip</span>Main idea behind EM
</div>
</div>
<div class="callout-body-container callout-body">
<ul>
<li>find a lower bound of the objective function in Equation&nbsp;1,</li>
<li>tighten the lower bound, and</li>
<li>maximise the tightest lower bound.</li>
</ul>
</div>
</div>
<p>The first two sub-steps combined are often known as the <em>Expectation</em> step (or E-step for short), while the last step is known as the <em>Maximisation</em> step (or M-step for short). These steps are presented in the following subsections.</p>
<section id="evidence-lower-bound-elbo" class="level4" data-number="2.2.1">
<h4 data-number="2.2.1" class="anchored" data-anchor-id="evidence-lower-bound-elbo"><span class="header-section-number">2.2.1</span> Evidence lower bound (ELBO)</h4>
<p>To find a lower bound of the objective function in Equation&nbsp;1, one can follow the <em>variational inference</em> approach to obtain the ELBO. In particular, let <img src="https://latex.codecogs.com/png.latex?q(%5Cmathbf%7Bz%7D)%20%3E%200"> be an arbitrary distribution of the latent variable <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bz%7D">. The in-complete log-likelihood in Equation&nbsp;1 can be re-written as follows: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi,%20%5Ctheta)%20&amp;%20=%20%5Cmathbb%7BE%7D_%7Bq(%5Cmathbf%7Bz%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi,%20%5Ctheta)%20%5Cright%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cmathbb%7BE%7D_%7Bq(%5Cmathbf%7Bz%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi,%20%5Ctheta)%20+%20%5Cln%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi,%20%5Ctheta)%20-%20%5Cln%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi,%20%5Ctheta)%20+%20%5Cln%20q(%5Cmathbf%7Bz%7D)%20-%20%5Cln%20q(%5Cmathbf%7Bz%7D)%20%5Cright%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cmathbb%7BE%7D_%7Bq(%5Cmathbf%7Bz%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi,%20%5Ctheta)%20+%20%5Cln%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi,%20%5Ctheta)%20-%20%5Cln%20q(%5Cmathbf%7Bz%7D)%20%5Cright%5D%20+%20%5Cmathbb%7BE%7D_%7Bq(%5Cmathbf%7Bz%7D)%7D%5Cleft%5B%20%5Cln%20q(%5Cmathbf%7Bz%7D)%20-%20%5Cln%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi,%20%5Ctheta)%20%5Cright%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cmathbb%7BE%7D_%7Bq(%5Cmathbf%7Bz%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%20%7C%20%5Cpi,%20%5Ctheta)%20-%20%5Cln%20q(%5Cmathbf%7Bz%7D)%20%5Cright%5D%20+%20%5Coperatorname%7BKL%7D%20%5Cleft%5B%20q(%5Cmathbf%7Bz%7D)%20%5C%7C%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathb
f%7Bx%7D,%20%5Cpi,%20%5Ctheta)%20%5Cright%5D,%0A%20%20%20%20%5Cend%7Baligned%7D%0A"> where: <img src="https://latex.codecogs.com/png.latex?%5Coperatorname%7BKL%7D%5B%20q%20%5C%7C%20p%20%5D"> is the Kullback-Leibler divergence (KL divergence for short) between probability distributions <img src="https://latex.codecogs.com/png.latex?q"> and <img src="https://latex.codecogs.com/png.latex?p">.</p>
<p>Since <img src="https://latex.codecogs.com/png.latex?%5Coperatorname%7BKL%7D%5B%20q%20%5C%7C%20p%20%5D%20%5Cge%200"> and <img src="https://latex.codecogs.com/png.latex?%5Coperatorname%7BKL%7D%5B%20q%20%5C%7C%20p%20%5D%20=%200"> iff <img src="https://latex.codecogs.com/png.latex?q%20=%20p">, the log-likelihood of interest can be lower-bounded as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi,%20%5Ctheta)%20%5Cge%20%5Cmathbb%7BE%7D_%7Bq(%5Cmathbf%7Bz%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%20%7C%20%5Cpi,%20%5Ctheta)%20-%20%5Cln%20q(%5Cmathbf%7Bz%7D)%20%5Cright%5D,%0A"> and the equality occurs iff <img src="https://latex.codecogs.com/png.latex?q(%5Cmathbf%7Bz%7D)%20=%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi,%20%5Ctheta)">, which is the posterior of the latent variable <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bz%7D"> after observing the data <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D">.</p>
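<p>The decomposition above can be checked numerically on a tiny discrete model. The sketch below is only an illustration: the two joint probabilities and the candidate distribution <code>q_arbitrary</code> are made-up numbers, not taken from any model in this post. It computes the ELBO, the KL divergence and the log-likelihood for a single observation with a binary latent variable, so that the gap between the log-likelihood and the ELBO can be compared with the KL term.</p>

```python
import math

# Toy model (all numbers are illustrative assumptions):
# z takes values in {0, 1} and joint[k] = Pr(x, z = k | pi, theta)
# for one fixed observation x.
joint = [0.12, 0.28]

evidence = sum(joint)                      # Pr(x | pi, theta)
log_evidence = math.log(evidence)          # the in-complete log-likelihood
posterior = [p / evidence for p in joint]  # Pr(z | x, pi, theta)


def elbo(q):
    """E_q[ln Pr(x, z | pi, theta) - ln q(z)] for a strictly positive q."""
    return sum(qk * (math.log(pk) - math.log(qk)) for qk, pk in zip(q, joint))


def kl(q, p):
    """KL[q || p] for two discrete distributions with full support."""
    return sum(qk * (math.log(qk) - math.log(pk)) for qk, pk in zip(q, p))


q_arbitrary = [0.5, 0.5]
gap = log_evidence - elbo(q_arbitrary)     # should match kl(q_arbitrary, posterior)
```

<p>Evaluating <code>elbo(posterior)</code> returns the log-likelihood itself up to rounding, which is the equality case stated above.</p>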
</section>
<section id="tightening-the-elbo" class="level4" data-number="2.2.2">
<h4 data-number="2.2.2" class="anchored" data-anchor-id="tightening-the-elbo"><span class="header-section-number">2.2.2</span> Tightening the ELBO</h4>
<p>To obtain the tightest lower bound, one must perform the following optimisation:</p>
<p><span id="eq-e-step"><img src="https://latex.codecogs.com/png.latex?%0Aq%5E%7B*%7D%20=%20%5Coperatorname*%7Bargmax%7D_%7Bq%7D%20%5Cmathbb%7BE%7D_%7Bq(%5Cmathbf%7Bz%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%20%7C%20%5Cpi,%20%5Ctheta)%20-%20%5Cln%20q(%5Cmathbf%7Bz%7D)%20%5Cright%5D.%0A%5Ctag%7B2%7D"></span></p>
<p>As mentioned above, the bound is tightest when <img src="https://latex.codecogs.com/png.latex?q%5E%7B*%7D(%5Cmathbf%7Bz%7D)%20=%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi,%20%5Ctheta)">, that is, when the “variational” posterior equals the true posterior of the latent variable <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bz%7D">. The true posterior can be obtained in closed form in certain simple cases, but is intractable when the model becomes more complex. In those cases, only a locally optimal “variational” posterior <img src="https://latex.codecogs.com/png.latex?q(%5Cmathbf%7Bz%7D)"> is calculated <span class="citation" data-cites="bernardo2003variational">(<span class="nocase">Bernardo et al.</span> 2003)</span>.</p>
<div class="callout callout-style-default callout-note callout-titled" title="True posterior in the E-step">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>True posterior in the E-step
</div>
</div>
<div class="callout-body-container callout-body">
<p>This observation explains why, in vanilla EM, the E-step is often described as calculating the true posterior of the latent variable <img src="https://latex.codecogs.com/png.latex?%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)">. The superscript <img src="https://latex.codecogs.com/png.latex?t"> denotes the parameter values at the <img src="https://latex.codecogs.com/png.latex?t">-th iteration. It marks them as fixed, so that they are not treated as optimisation variables when maximising the completed log-likelihood in the M-step. Instead of following that convention, <img src="https://latex.codecogs.com/png.latex?q%5E%7B*%7D"> is used here to avoid confusion.</p>
</div>
</div>
</section>
<section id="maximising-the-possibly-tightest-lower-bound" class="level4" data-number="2.2.3">
<h4 data-number="2.2.3" class="anchored" data-anchor-id="maximising-the-possibly-tightest-lower-bound"><span class="header-section-number">2.2.3</span> Maximising the possibly-tightest lower bound</h4>
<p>Finally, the possibly-tightest lower bound is then maximised with respect to the parameters <img src="https://latex.codecogs.com/png.latex?%5Cpi"> and <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> as follows:</p>
<p><span id="eq-m-step"><img src="https://latex.codecogs.com/png.latex?%0A%5Cpi%5E%7B(t%20+%201)%7D,%20%5Ctheta%5E%7B(t%20+%201)%7D%20%5Cgets%20%5Coperatorname*%7Bargmax%7D_%7B%5Cpi,%20%5Ctheta%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Cmathbb%7BE%7D_%7Bq%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bi%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D_%7Bi%7D,%20%5Cmathbf%7Bz%7D_%7Bi%7D%20%7C%20%5Cpi,%20%5Ctheta)%20-%20%5Ccancel%7B%5Cln%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D)%7D%20%5Cright%5D%20+%20%5Cln%20%5CPr(%5Cpi)%20+%20%5Cln%20%5CPr(%5Ctheta).%0A%5Ctag%7B3%7D"></span></p>
<p>In summary, instead of maximising the difficult-to-calculate objective function in Equation&nbsp;1, the EM algorithm carries out the following alternating optimisation:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cmax_%7B%5Cpi,%20%5Ctheta%7D%20%5Cmax_%7Bq_%7Bi%7D%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Cmathbb%7BE%7D_%7Bq(%5Cmathbf%7Bz%7D_%7Bi%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D_%7Bi%7D,%20%5Cmathbf%7Bz%7D_%7Bi%7D%20%7C%20%5Cpi,%20%5Ctheta)%20-%20%5Cln%20q(%5Cmathbf%7Bz%7D)%20%5Cright%5D%20+%20%5Cln%20%5CPr(%5Cpi)%20+%20%5Cln%20%5CPr(%5Ctheta).%0A"></p>
<p>The whole EM algorithm is summarised in Algorithm 1.</p>
<div id="algo-em" class="pseudocode-container quarto-float" data-caption-prefix="Algorithm" data-line-number="true" data-line-number-punc=":" data-comment-delimiter="//" data-indent-size="1.2em" data-pseudocode-number="1" data-no-end="false">
<div class="pseudocode">
\begin{algorithm} \caption{Expectation - Maximisation algorithm} \begin{algorithmic} \Procedure{EM}{$\mathbf{x}$} \State initialise mixture coefficient $\pi$ \State initialise $\theta$ \While{not converged} \State calculate the ELBO: $Q \gets \operatorname{E-step}(\mathbf{x}, \pi, \theta)$ \State maximise the ELBO: $\pi, \theta \gets \operatorname{M-step}(Q, \pi, \theta)$ \EndWhile \State return $\pi, \theta$ \EndProcedure \end{algorithmic} \end{algorithm}
</div>
</div>
</section>
</section>
<section id="convergence-of-the-em-algorithm" class="level3" data-number="2.3">
<h3 data-number="2.3" class="anchored" data-anchor-id="convergence-of-the-em-algorithm"><span class="header-section-number">2.3</span> Convergence of the EM algorithm</h3>
<p>The following theorem states that the EM algorithm never decreases the log-likelihood from one iteration to the next. For simplicity, the priors <img src="https://latex.codecogs.com/png.latex?%5CPr(%5Cpi)"> and <img src="https://latex.codecogs.com/png.latex?%5CPr(%5Ctheta)"> are omitted from the proof below, but extending it to include these prior terms is straightforward.</p>
<div id="thm-convergence" class="theorem">
<p><span class="theorem-title"><strong>Theorem 1</strong></span> Assume that <img src="https://latex.codecogs.com/png.latex?q%5E%7B*%7D(%5Cmathbf%7Bz%7D)%20=%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi,%20%5Ctheta)">. Then, after each EM iteration, the log-likelihood <img src="https://latex.codecogs.com/png.latex?%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi,%20%5Ctheta)"> is non-decreasing. Mathematically, it can be written as follows: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi%5E%7B(t%20+%201)%7D,%20%5Ctheta%5E%7B(t%20+%201)%7D)%20%5Cge%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D),%0A"> where the superscript denotes the value obtained after that iteration.</p>
</div>
<div class="proof">
<p><span class="proof-title"><em>Proof</em>. </span>The log-likelihood of interest can be written as: <span id="eq-likelihood_theta"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi,%20%5Ctheta)%20&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi,%20%5Ctheta)%20%5Cright%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%20%7C%20%5Cpi,%20%5Ctheta)%20-%20%5Cln%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi,%20%5Ctheta)%20%5Cright%5D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A%5Ctag%7B4%7D"></span></p>
<p>Since it holds for any <img src="https://latex.codecogs.com/png.latex?(%5Cpi,%20%5Ctheta)">, substituting <img src="https://latex.codecogs.com/png.latex?%5Cpi%20=%20%5Cpi%5E%7B(t)%7D"> and <img src="https://latex.codecogs.com/png.latex?%5Ctheta%20=%20%5Ctheta%5E%7B(t)%7D"> gives: <span id="eq-likelihood_after_iteration_nth"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%20=%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%20%7C%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%20-%20%5Cln%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%20%5Cright%5D.%0A%5Ctag%7B5%7D"></span></p>
<p>Subtracting Equation&nbsp;5 from Equation&nbsp;4 gives the following: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20&amp;%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi,%20%5Ctheta)%20-%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%20%7C%20%5Cpi,%20%5Ctheta)%20%5Cright.%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20%5Cleft.%20-%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%20%7C%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%20+%20%5Cln%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%20-%20%5Cln%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi,%20%5Ctheta)%20%5Cright%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%20%7C%20%5Cpi,%20%5Ctheta)%20%5Cright.%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20%5Cleft.%20-%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%20%7C%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%20%5Cright%5D%20+%20%5Coperatorname%7BKL%7D%20%5Cleft%5B%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%20%5C%7C%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi,%20%5Ctheta)%20%5Cright%5D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"></p>
<p>Since the KL divergence is non-negative, it follows that: <span id="eq-likelihood_difference"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20&amp;%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi,%20%5Ctheta)%20-%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20%5Cge%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%20%7C%20%5Cpi,%20%5Ctheta)%20-%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%20%7C%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%20%5Cright%5D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A%5Ctag%7B6%7D"></span></p>
<p>In the M-step, the parameters <img src="https://latex.codecogs.com/png.latex?(%5Cpi%5E%7B(t%20+%201)%7D,%20%5Ctheta%5E%7B(t%20+%201)%7D)"> are obtained by maximising the first term in the right hand side: <img src="https://latex.codecogs.com/png.latex?%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%20%7C%20%5Cpi,%20%5Ctheta)%20%5Cright%5D"> w.r.t. <img src="https://latex.codecogs.com/png.latex?(%5Cpi,%20%5Ctheta)">. Thus, according to the definition of the maximisation: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%20%7C%20%5Cpi%5E%7B(t%20+%201)%7D,%20%5Ctheta%5E%7B(t%20+%201)%7D)%20%5Cright%5D%20%5Cge%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%7D%20%5Cleft%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7Bz%7D%20%7C%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D)%20%5Cright%5D.%0A"></p>
<p>Hence, one can conclude that: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi%5E%7B(t%20+%201)%7D,%20%5Ctheta%5E%7B(t%20+%201)%7D)%20%5Cge%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi%5E%7B(t)%7D,%20%5Ctheta%5E%7B(t)%7D).%0A"></p>
</div>
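<p>The monotonicity stated in Theorem 1 can be observed empirically. The sketch below is a minimal illustration under stated assumptions: it uses a one-dimensional mixture of two Gaussians, ignores the priors (as in the proof), and relies on the standard closed-form maximum-likelihood updates for that model. The data, the seed, the initial values and the helper names (<code>normal_pdf</code>, <code>em_step</code>) are all made up for this example.</p>

```python
import math
import random


def normal_pdf(x, mean, var):
    """Density of a univariate Gaussian."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def log_likelihood(data, pi, mu, var):
    """In-complete log-likelihood: sum_i ln sum_k pi_k N(x_i; mu_k, var_k)."""
    return sum(
        math.log(sum(pi[k] * normal_pdf(x, mu[k], var[k]) for k in range(len(pi))))
        for x in data
    )


def em_step(data, pi, mu, var):
    K = len(pi)
    # E-step: exact posterior q*(z_ik = 1) under the current parameters.
    resp = []
    for x in data:
        w = [pi[k] * normal_pdf(x, mu[k], var[k]) for k in range(K)]
        total = sum(w)
        resp.append([wk / total for wk in w])
    # M-step: closed-form maximum-likelihood updates (no priors).
    n_k = [sum(r[k] for r in resp) for k in range(K)]
    new_pi = [n / len(data) for n in n_k]
    new_mu = [sum(r[k] * x for r, x in zip(resp, data)) / n_k[k] for k in range(K)]
    new_var = [
        sum(r[k] * (x - new_mu[k]) ** 2 for r, x in zip(resp, data)) / n_k[k]
        for k in range(K)
    ]
    return new_pi, new_mu, new_var


random.seed(0)
data = [random.gauss(-2.0, 1.0) for _ in range(100)]
data += [random.gauss(3.0, 0.5) for _ in range(100)]

pi, mu, var = [0.5, 0.5], [-1.0, 1.0], [1.0, 1.0]
trace = [log_likelihood(data, pi, mu, var)]
for _ in range(20):
    pi, mu, var = em_step(data, pi, mu, var)
    trace.append(log_likelihood(data, pi, mu, var))
```

<p>The list <code>trace</code> holds the log-likelihood before the first iteration and after each of the 20 iterations. Up to floating-point rounding, it should never decrease.</p>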
</section>
</section>
<section id="applications-of-em-in-finite-mixture-models" class="level2" data-number="3">
<h2 data-number="3" class="anchored" data-anchor-id="applications-of-em-in-finite-mixture-models"><span class="header-section-number">3</span> Applications of EM in finite mixture models</h2>
<p>One typical application of the EM algorithm is maximum likelihood estimation for finite mixture models. This section therefore discusses the application of EM to Gaussian and multinomial mixture models.</p>
<section id="gaussian-mixture-models" class="level3" data-number="3.1">
<h3 data-number="3.1" class="anchored" data-anchor-id="gaussian-mixture-models"><span class="header-section-number">3.1</span> Gaussian mixture models</h3>
<p>The Gaussian mixture distribution can be written as a <em>convex</em> combination of <img src="https://latex.codecogs.com/png.latex?K"> Gaussian components: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi,%20%5Cmu,%20%5CSigma)%20=%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20%5Cpi_%7Bk%7D%20%5C,%20%5Cmathcal%7BN%7D(%5Cmathbf%7Bx%7D;%20%5Cmu_%7Bk%7D,%20%5CSigma_%7Bk%7D),%0A"> where: <img src="https://latex.codecogs.com/png.latex?%5Cpi_%7Bk%7D%20%5Cin%20%5B0,%201%5D"> and <img src="https://latex.codecogs.com/png.latex?%5Cpmb%7B%5Cpi%7D%5E%7B%5Ctop%7D%20%5Cpmb%7B1%7D%20=%201">.</p>
<section id="data-generation" class="level4" data-number="3.1.1">
<h4 data-number="3.1.1" class="anchored" data-anchor-id="data-generation"><span class="header-section-number">3.1.1</span> Data generation</h4>
<p>A data-point of the above Gaussian mixture distribution can be generated as follows:</p>
<ul>
<li>sample a probability <img src="https://latex.codecogs.com/png.latex?%5Cpi"> from a Dirichlet prior: <img src="https://latex.codecogs.com/png.latex?%5Cpi%20%5Csim%20%5CPr(%5Cpi%20%7C%20%5Calpha)%20=%20%5Coperatorname%7BDir%7D(%5Cpi%20%7C%20%5Calpha)">,</li>
<li>sample <img src="https://latex.codecogs.com/png.latex?K"> sets of parameters <img src="https://latex.codecogs.com/png.latex?(%5Cmu_%7Bk%7D,%20%5CSigma_%7Bk%7D)"> from a normal-inverse-Wishart prior: <img src="https://latex.codecogs.com/png.latex?(%5Cmu_%7Bk%7D,%20%5CSigma_%7Bk%7D)%20%5Csim%20%5CPr(%5Cmu,%20%5CSigma%20%7C%20m,%20%5Clambda,%20%5CPsi,%20%5Cnu)%20=%20%5Coperatorname%7BNIW%7D(%5Cmu,%20%5CSigma%20%7C%20m,%20%5Clambda,%20%5CPsi,%20%5Cnu)">,</li>
<li>sample the index of a Gaussian component: <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bz%7D%20%5Csim%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cpi)%20=%20%5Coperatorname%7BCategorical%7D(%5Cmathbf%7Bz%7D%20%7C%20%5Cpmb%7B%5Cpi%7D)">, then</li>
<li>sample a data-point from the corresponding Gaussian component: <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D%20%5Csim%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cmathbf%7Bz%7D,%20%5Cmu,%20%5CSigma)%20=%20%5Cmathcal%7BN%7D(%5Cmathbf%7Bx%7D%7C%20%5Cmu_%7Bk%7D,%20%5CSigma_%7Bk%7D)">, where <img src="https://latex.codecogs.com/png.latex?z_%7Bk%7D%20=%201">.</li>
</ul>
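<p>The four sampling steps above can be sketched with NumPy as follows. This is an illustration rather than a reference implementation: the hyper-parameter values are arbitrary, the helper <code>sample_inverse_wishart</code> is a hypothetical name, and it assumes an integer number of degrees of freedom so that the inverse-Wishart draw can be built from Gaussian samples. The component index is stored as an integer instead of a one-hot vector.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

K, D, N = 3, 2, 500          # components, dimension, data points (illustrative)
alpha = np.ones(K)           # Dirichlet concentration
m = np.zeros(D)              # NIW prior mean
lam = 0.1                    # NIW scaling of the mean covariance (lambda)
Psi = np.eye(D)              # NIW scale matrix
nu = D + 2                   # NIW degrees of freedom (integer for this sampler)


def sample_inverse_wishart(Psi, nu, rng):
    """Draw Sigma ~ IW(Psi, nu) for an integer nu that is at least D.

    If W ~ Wishart(Psi^{-1}, nu), then W^{-1} ~ IW(Psi, nu); W is built as
    the sum of nu outer products of N(0, Psi^{-1}) draws.
    """
    dim = Psi.shape[0]
    g = rng.multivariate_normal(np.zeros(dim), np.linalg.inv(Psi), size=nu)
    return np.linalg.inv(g.T @ g)


# Step 1: mixture coefficients from the Dirichlet prior.
pi = rng.dirichlet(alpha)
# Step 2: K pairs (mu_k, Sigma_k) from the NIW prior, with mu_k | Sigma_k ~ N(m, Sigma_k / lam).
Sigma = np.stack([sample_inverse_wishart(Psi, nu, rng) for _ in range(K)])
mu = np.stack([rng.multivariate_normal(m, Sigma[k] / lam) for k in range(K)])
# Step 3: component index of every data point.
z = rng.choice(K, size=N, p=pi)
# Step 4: data points from the selected Gaussian components.
x = np.stack([rng.multivariate_normal(mu[k], Sigma[k]) for k in z])
```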
<p>The data generation process can also be visualised in the graphical model shown below.</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">%%{
    init: {
        'theme': 'base',
        'themeVariables': {
            'primaryColor': '#ffffff'
        }
    }
}%%
flowchart LR
    subgraph data["data"]
        direction LR
        z((z)):::rv --&gt; x((x)):::rv
    end
    alpha((α)):::notfilled --&gt; pi((π)):::params --&gt; z
    sigma((Σ)):::params --&gt; mu
    psi((Ψ)):::notfilled --&gt; sigma
    nu((ν)):::notfilled --&gt; sigma
    sigma --&gt; x
    mu0((m)):::notfilled --&gt; mu((μ)):::params --&gt; x
    lambda((λ)):::notfilled --&gt; mu

    style z fill: none
    classDef params stroke: #000, fill: none
    classDef rv stroke: #000
    classDef notfilled fill: none
    linkStyle default stroke: #000
    style data fill: none
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
</section>
<section id="objective" class="level4" data-number="3.1.2">
<h4 data-number="3.1.2" class="anchored" data-anchor-id="objective"><span class="header-section-number">3.1.2</span> Objective</h4>
<p>Given a set of data points <img src="https://latex.codecogs.com/png.latex?%5C%7B%5Cmathbf%7Bx%7D_%7Bi%7D%5C%7D_%7Bi%20=%201%7D%5E%7BN%7D"> sampled from the Gaussian mixture distribution, the aim is to infer a point estimate, in particular the MAP estimate, of <img src="https://latex.codecogs.com/png.latex?(%5Cpi,%20%5Cmu,%20%5CSigma)">. Such an objective can be written as follows:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%20%20%20%20&amp;%20%5Cmax_%7B%5Cpi,%20%5Cmu,%20%5CSigma%7D%20%5Cln%20%5CPr(%5Cpi,%20%5Cmu,%20%5CSigma%20%7C%20%5C%7B%5Cmathbf%7Bx%7D_%7Bi%7D%5C%7D_%7Bi%20=%201%7D%5E%7BN%7D,%20%5Calpha,%20m,%20%5Clambda,%20%5CPsi,%20%5Cnu)%20%5C%5C%0A%20%20%20%20&amp;=%20%5Cmax_%7B%5Cpi,%20%5Cmu,%20%5CSigma%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D_%7Bi%7D%20%7C%20%5Cpi,%20%5Cmu,%20%5CSigma)%20+%20%5Cln%20%5Coperatorname%7BDir%7D(%5Cpi%20%7C%20%5Calpha)%20+%20%5Cln%20%5Coperatorname%7BNIW%7D(%5Cmu,%20%5CSigma%20%7C%20m,%20%5Clambda,%20%5CPsi,%20%5Cnu).%0A%5Cend%7Baligned%7D%0A"></p>
</section>
<section id="parameter-inference" class="level4" data-number="3.1.3">
<h4 data-number="3.1.3" class="anchored" data-anchor-id="parameter-inference"><span class="header-section-number">3.1.3</span> Parameter inference</h4>
<p>In this case, one can simply follow the EM algorithm presented in Section&nbsp;2.2. Note that the likelihood of <img src="https://latex.codecogs.com/png.latex?N"> i.i.d. data points can be written as:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cprod_%7Bi%20=%201%7D%5E%7BN%7D%20%5CPr(%5Cmathbf%7Bx%7D_%7Bi%7D%20%7C%20%5Cpi,%20%5Ctheta)%20=%20%5Cprod_%7Bi%20=%201%7D%5E%7BN%7D%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20%5CPr(%5Cmathbf%7Bx%7D_%7Bi%7D%20%7C%20%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201,%20%5Ctheta)%20%5C,%20%5CPr(z_%7Bik%7D%20=%201%20%7C%20%5Cpi).%0A"></p>
<p><strong>E-step:</strong> optimises the lower bound with respect to the “variational” posterior. As shown in Section&nbsp;2.2, <img src="https://latex.codecogs.com/png.latex?q%5E%7B*%7D%20=%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi,%20%5Cmu,%20%5CSigma)"> results in the tightest bound. Fortunately, in this case of Gaussian mixture models, the true posterior <img src="https://latex.codecogs.com/png.latex?%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cmathbf%7Bx%7D,%20%5Cpi,%20%5Cmu,%20%5CSigma)"> can be calculated in closed-form as follows:</p>
<p><span id="eq-gmm_e_step"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cboxed%7B%0A%20%20%20%20%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20%20%20%20%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20&amp;%20=%20%5CPr(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201%20%7C%20%5Cmathbf%7Bx%7D_%7Bi%7D,%20%5Cpi%5E%7B(t)%7D,%20%5Cmu%5E%7B(t)%7D,%20%5CSigma%5E%7B(t)%7D)%20%5C%5C%0A%20%20%20%20%20%20%20%20%20%20%20%20&amp;%20=%20%5Cfrac%7B%5CPr(%5Cmathbf%7Bx%7D_%7Bi%7D%20%7C%20%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201,%20%5Cmu%5E%7B(t)%7D,%20%5CSigma%5E%7B(t)%7D)%20%5C,%20%5CPr(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201%20%7C%20%5Cpi%5E%7B(t)%7D)%7D%7B%5Csum_%7Bj%20=%201%7D%5E%7BK%7D%20%5CPr(%5Cmathbf%7Bx%7D_%7Bi%7D%20%7C%20%5Cmathbf%7Bz%7D_%7Bij%7D%20=%201,%20%5Cmu%5E%7B(t)%7D,%20%5CSigma%5E%7B(t)%7D)%20%5C,%20%5CPr(%5Cmathbf%7Bz%7D_%7Bij%7D%20=%201%20%7C%20%5Cpi%5E%7B(t)%7D)%7D%20%5C%5C%0A%20%20%20%20%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20(%5Ctext%7BBayes'%20rule%7D)%20%5C%5C%0A%20%20%20%20%20%20%20%20%20%20%20%20&amp;%20=%20%5Cfrac%7B%5Cpi_%7Bk%7D%20%5C,%20%5Cmathcal%7BN%7D(%5Cmathbf%7Bx%7D_%7Bi%7D;%20%5Cmu_%7Bk%7D%5E%7B(t)%7D,%20%5CSigma_%7Bk%7D%5E%7B(t)%7D)%7D%7B%5Csum_%7Bj%20=%201%7D%5E%7BK%7D%20%5Cpi_%7Bj%7D%20%5C,%20%5Cmathcal%7BN%7D(%5Cmathbf%7Bx%7D_%7Bi%7D;%20%5Cmu_%7Bj%7D%5E%7B(t)%7D,%20%5CSigma_%7Bj%7D%5E%7B(t)%7D)%7D.%0A%20%20%20%20%20%20%20%20%5Cend%7Baligned%7D%0A%20%20%20%20%7D%0A%5Ctag%7B7%7D"></span></p>
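<p>Equation&nbsp;7 can be evaluated for all data points at once. The sketch below is one possible NumPy implementation, not the only one: the function names <code>log_gaussian</code> and <code>e_step</code> and the toy inputs are assumptions made for this example. It works in log-space and subtracts the row-wise maximum before exponentiating, a common trick to avoid numerical underflow that is not part of the derivation above.</p>

```python
import numpy as np


def log_gaussian(x, mu, Sigma):
    """ln N(x_i; mu, Sigma) for each row x_i of x, returned with shape (N,)."""
    dim = x.shape[1]
    diff = x - mu
    chol = np.linalg.cholesky(Sigma)
    sol = np.linalg.solve(chol, diff.T)          # (D, N) whitened residuals
    maha = np.sum(sol ** 2, axis=0)              # squared Mahalanobis distances
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (dim * np.log(2.0 * np.pi) + log_det + maha)


def e_step(x, pi, mu, Sigma):
    """Return q*(z_ik = 1) as an (N, K) array whose rows sum to one."""
    log_w = np.stack(
        [np.log(pi[k]) + log_gaussian(x, mu[k], Sigma[k]) for k in range(len(pi))],
        axis=1,
    )
    log_w -= log_w.max(axis=1, keepdims=True)    # stabilise before exponentiating
    w = np.exp(log_w)
    return w / w.sum(axis=1, keepdims=True)


# Two well-separated components and three query points (illustrative values).
x = np.array([[0.0, 0.0], [4.0, 4.0], [2.0, 2.0]])
pi = np.array([0.5, 0.5])
mu = np.array([[0.0, 0.0], [4.0, 4.0]])
Sigma = np.stack([np.eye(2), np.eye(2)])
resp = e_step(x, pi, mu, Sigma)
```

<p>In this toy setting the first two points sit on the component means and are assigned almost entirely to their own component, while the third point lies halfway between the means and is shared equally.</p>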
<p><strong>M-step:</strong> maximises the “tightest” lower-bound w.r.t. model parameters <img src="https://latex.codecogs.com/png.latex?(%5Cpi,%20%5Cmu,%20%5CSigma)">: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20&amp;%20%5Coperatorname*%7Bargmax%7D_%7B%5Cpi,%20%5Cmu,%20%5CSigma%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Cmathbb%7BE%7D_%7Bq%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bi%7D)%7D%20%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D_%7Bi%7D%20%7C%20%5Cmathbf%7Bz%7D_%7Bi%7D,%20%5Cmu,%20%5CSigma)%20+%20%5Cln%20%5CPr(%5Cmathbf%7Bz%7D_%7Bi%7D%20%7C%20%5Cpi)%20%5D%20+%20%5Cln%20%5CPr(%5Cpi%20%7C%20%5Calpha)%20+%20%5Cln%20%5CPr(%5Cmu,%20%5CSigma%20%7C%20m,%20%5Clambda,%20%5CPsi,%20%5Cnu)%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Coperatorname*%7Bargmax%7D_%7B%5Cpi,%20%5Cmu,%20%5CSigma%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5Cleft%5B%5Cln%20%5CPr(%5Cmathbf%7Bx%7D_%7Bi%7D%20%7C%20%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201,%20%5Cpi,%20%5Cmu,%20%5CSigma)%20+%20%5Cln%20%5CPr(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201%7C%20%5Cpi)%20%5Cright%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20+%20%5Cln%20%5Coperatorname%7BDir%7D(%5Cpi%20%7C%20%5Calpha)%20+%20%5Cln%20%5Coperatorname%7BNIW%7D(%5Cmu_%7Bk%7D,%20%5CSigma_%7Bk%7D%20%7C%20m,%20%5Clambda,%20%5CPsi,%20%5Cnu)%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Coperatorname*%7Bargmax%7D_%7B%5Cpi,%20%5Cmu,%20%5CSigma%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5Cleft%5B%20%5Cln%20%5Cmathcal%7BN%7D(%5Cmathbf%7Bx%7D_%7Bi%7D;%20%5Cmu_%7Bk%7D,%20%5CSigma_%7Bk%7D)%20+%20%5Cln%20%5Cpi_%7Bk%7D%20%5Cright%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20+%20(%5Calpha_%7Bk%7D%20-%201)%20%5Cln%20%5Cpi_%7Bk%7D%20+%20%5Cln%20%5Cmathcal%7BN%7D%20%5Cleft(%20%5Cmu_%7Bk%7D%20%5Cleft%7C%20m,%20%5Cfrac%7B1%7D%7B%5Clambda%7D%20%5CSigma_%7Bk%7D%20%5Cright.%20%5Cright)%20+%20%5Cln%20%5Cmathcal%7BW%7D%5E%7B-1%7D%20%5Cleft(%20%5CSigma_%7Bk%7D%20%7C%20%5CPsi,%20%5Cnu%20%5Cright)%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Coperatorname*%7Bargmax%7D_%7B%5Cpi,%20%5Cmu,%20%5CSigma%7D%20-%5Cfrac%7B1%7D%7B2%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5Cleft%5B%20%5Cln%20%5Cleft%7C%20%5CSigma_%7Bk%7D%20%5Cright%7C%20+%20(%5Cmathbf%7Bx%7D_%7Bi%7D%20-%20%5Cmu_%7Bk%7D)%5E%7B%5Ctop%7D%20%5CSigma_%7Bk%7D%5E%7B-1%7D%20(%5Cmathbf%7Bx%7D_%7Bi%7D%20-%20%5Cmu_%7Bk%7D)%20-%202%20%5Cln%20%5Cpi_%7Bk%7D%20%5Cright%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20+%20(%5Calpha_%7Bk%7D%20-%201)%20%5Cln%20%5Cpi_%7Bk%7D%20-%20%5Cfrac%7B%5Cnu%20+%20D%20+%202%7D%7B2%7D%20%5Cln%20%7C%5CSigma_%7Bk%7D%7C%20-%20%5Cfrac%7B1%7D%7B2%7D%20%5Coperatorname%7BTr%7D%20%5Cleft(%20%5CPsi%20%5CSigma_%7Bk%7D%5E%7B-1%7D%20%5Cright)%20-%20%5Cfrac%7B%5Clambda%7D%7B2%7D%20(%5Cmu_%7Bk%7D%20-%20m)%5E%7B%5Ctop%7D%20%5CSigma_%7Bk%7D%5E%7B-1%7D%20(%5Cmu_%7Bk%7D%20-%20m).%0A%20%20%20%20%5Cend%7Baligned%7D%0A"></p>
<p>Taking the derivative with respect to <img src="https://latex.codecogs.com/png.latex?%5Cmu_%7Bk%7D"> and setting it to zero gives:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20&amp;%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5CSigma_%7Bk%7D%5E%7B-1%7D%20(%5Cmathbf%7Bx%7D_%7Bi%7D%20-%20%5Cmu_%7Bk%7D)%20-%20%5Clambda%20%5CSigma_%7Bk%7D%5E%7B-1%7D%20(%5Cmu_%7Bk%7D%20-%20m)%20=%200%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5CLeftrightarrow%20%5Cleft%5B%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20+%20%5Clambda%20%5Cright%5D%20%5Cmu_%7Bk%7D%20=%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5Cmathbf%7Bx%7D_%7Bi%7D%20+%20%5Clambda%20m.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"></p>
<p>Or:</p>
<p><span id="eq-mu-k"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cboxed%7B%0A%20%20%20%20%20%20%20%20%5Cmu_%7Bk%7D%20=%20%5Cfrac%7B%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5Cmathbf%7Bx%7D_%7Bi%7D%20+%20%5Clambda%20m%7D%7B%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20+%20%5Clambda%7D.%0A%20%20%20%20%7D%0A%5Ctag%7B8%7D"></span></p>
<p>Similarly for <img src="https://latex.codecogs.com/png.latex?%5CSigma_%7Bk%7D">:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20&amp;%20-%5Cfrac%7B1%7D%7B2%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5Cleft%5B%20%5CSigma_%7Bk%7D%5E%7B-1%7D%20-%20%5CSigma_%7Bk%7D%5E%7B-1%7D%20(%5Cmathbf%7Bx%7D_%7Bi%7D%20-%20%5Cmu_%7Bk%7D)%20(%5Cmathbf%7Bx%7D_%7Bi%7D%20-%20%5Cmu_%7Bk%7D)%5E%7B%5Ctop%7D%20%5CSigma_%7Bk%7D%5E%7B-1%7D%20%5Cright%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20+%20%5Cfrac%7B1%7D%7B2%7D%20%5CSigma_%7Bk%7D%5E%7B-1%7D%20%5CPsi%20%5CSigma_%7Bk%7D%5E%7B-1%7D%20-%20%5Cfrac%7B%5Cnu%20+%20D%20+%202%7D%7B2%7D%20%5CSigma_%7Bk%7D%5E%7B-1%7D%20+%20%5Cfrac%7B%5Clambda%7D%7B2%7D%20%5CSigma_%7Bk%7D%5E%7B-1%7D%20(%5Cmu_%7Bk%7D%20-%20m)%20(%5Cmu_%7Bk%7D%20-%20m)%5E%7B%5Ctop%7D%20%5CSigma_%7Bk%7D%5E%7B-1%7D%20=%200.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"></p>
<p>To solve for <img src="https://latex.codecogs.com/png.latex?%5CSigma_%7Bk%7D">, both sides are left- and right-multiplied by the covariance matrix <img src="https://latex.codecogs.com/png.latex?%5CSigma_%7Bk%7D"> to obtain:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20&amp;%20-%5Cfrac%7B1%7D%7B2%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5Cleft%5B%20%5CSigma_%7Bk%7D%20-%20(%5Cmathbf%7Bx%7D_%7Bi%7D%20-%20%5Cmu_%7Bk%7D)%20(%5Cmathbf%7Bx%7D_%7Bi%7D%20-%20%5Cmu_%7Bk%7D)%5E%7B%5Ctop%7D%20%5Cright%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20+%20%5Cfrac%7B1%7D%7B2%7D%20%5CPsi%20-%20%5Cfrac%7B%5Cnu%20+%20D%20+%202%7D%7B2%7D%20%5CSigma_%7Bk%7D%20+%20%5Cfrac%7B%5Clambda%7D%7B2%7D%20(%5Cmu_%7Bk%7D%20-%20m)%20(%5Cmu_%7Bk%7D%20-%20m)%5E%7B%5Ctop%7D%20=%200%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5CLeftrightarrow%20%5Cleft%5B%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20+%20%5Cnu%20+%20D%20+%202%20%5Cright%5D%20%5CSigma_%7Bk%7D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20=%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20(%5Cmathbf%7Bx%7D_%7Bi%7D%20-%20%5Cmu_%7Bk%7D)%20(%5Cmathbf%7Bx%7D_%7Bi%7D%20-%20%5Cmu_%7Bk%7D)%5E%7B%5Ctop%7D%20+%20%5CPsi%20+%20%5Clambda%20(%5Cmu_%7Bk%7D%20-%20m)%20(%5Cmu_%7Bk%7D%20-%20m)%5E%7B%5Ctop%7D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"></p>
<p>Or:</p>
<p><span id="eq-sigma-k"><img src="https://latex.codecogs.com/png.latex?%0A%5Cboxed%7B%0A%20%20%20%20%5CSigma_%7Bk%7D%20=%20%5Cfrac%7B%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20(%5Cmathbf%7Bx%7D_%7Bi%7D%20-%20%5Cmu_%7Bk%7D)%20(%5Cmathbf%7Bx%7D_%7Bi%7D%20-%20%5Cmu_%7Bk%7D)%5E%7B%5Ctop%7D%20+%20%5CPsi%20+%20%5Clambda%20(%5Cmu_%7Bk%7D%20-%20m)%20(%5Cmu_%7Bk%7D%20-%20m)%5E%7B%5Ctop%7D%7D%7B%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20+%20%5Cnu%20+%20D%20+%202%7D.%0A%7D%0A%5Ctag%7B9%7D"></span></p>
<p>One can further substitute <img src="https://latex.codecogs.com/png.latex?%5Cmu_%7Bk%7D"> in Equation&nbsp;8 into Equation&nbsp;9 to obtain an expression for <img src="https://latex.codecogs.com/png.latex?%5CSigma_%7Bk%7D"> that only depends on observed data <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D"> and prior parameters.</p>
<p>Finally, one can obtain the optimal value for the mixture coefficient <img src="https://latex.codecogs.com/png.latex?%5Cpi_%7Bk%7D"> in a similar way, except it is now a constrained optimisation. Such an optimisation can be written as follows:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%20%20%20%20&amp;%20%5Cmax_%7B%5Cpi%7D%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20%5Cleft%5B%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20+%20%5Calpha_%7Bk%7D%20-%201%20%5Cright%5D%20%5Cln%20%5Cpi_%7Bk%7D%20%5C%5C%0A%20%20%20%20&amp;%20%5Ctext%7Bsubject%20to:%20%7D%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20%5Cpi_%7Bk%7D%20=%201.%0A%5Cend%7Baligned%7D%0A"></p>
<p>The constrained optimisation above can simply be solved with a Lagrange multiplier. The result for <img src="https://latex.codecogs.com/png.latex?%5Cpi_%7Bk%7D"> can then be expressed as:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%5Cboxed%7B%0A%20%20%20%20%5Cpi_%7Bk%7D%20=%20%5Cfrac%7B%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20+%20%5Calpha_%7Bk%7D%20-%201%7D%7BN%20-%20K%20+%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20%5Calpha_%7Bk%7D%7D.%0A%7D%0A"></p>
<p>One can also refer to Chapter 10.2 in <span class="citation" data-cites="bishop2006pattern">(Bishop 2006)</span> for a similar derivation and result.</p>
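<p>As a minimal sketch of the boxed update for the mixture coefficients, the snippet below evaluates it on a toy set of responsibilities. The function name <code>update_mixture_weights</code> and the toy values are assumptions made for illustration only, and the Dirichlet concentrations are taken to be at least one so that the result stays non-negative.</p>

```python
def update_mixture_weights(resp, alpha):
    """MAP update of the mixture coefficients.

    resp[i][k] holds q*(z_ik = 1) and alpha[k] is the Dirichlet concentration.
    """
    num_points = len(resp)
    num_components = len(alpha)
    # Effective number of data-points assigned to each component.
    counts = [sum(row[k] for row in resp) for k in range(num_components)]
    # Denominator of the boxed result: N - K + sum_k alpha_k.
    denominator = num_points - num_components + sum(alpha)
    return [(counts[k] + alpha[k] - 1.0) / denominator for k in range(num_components)]


# Hypothetical responsibilities for N = 3 data-points and K = 2 components.
resp = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]
alpha = [2.0, 2.0]
pi = update_mixture_weights(resp, alpha)
print(pi)
```

<p>Because each row of the responsibilities sums to one, the returned coefficients sum to one as well.</p>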
</section>
</section>
<section id="multinomial-mixture-models" class="level3" data-number="3.2">
<h3 data-number="3.2" class="anchored" data-anchor-id="multinomial-mixture-models"><span class="header-section-number">3.2</span> Multinomial mixture models</h3>
<p>Similar to the Gaussian mixture models, a multinomial mixture model can also be written as:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cpi,%20m,%20%5Crho)%20=%20%5Csum_%7B%5Cmathbf%7Bz%7D%7D%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cpi)%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cmathbf%7Bz%7D,%20m,%20%5Crho)%20=%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20%5Cpi_%7Bk%7D%20%5Cmathrm%7BMult%7D(%5Cmathbf%7Bx%7D;%20m,%20%5Crho_%7Bk%7D).%0A"></p>
<div class="callout callout-style-default callout-note callout-titled" title="$m$ is given">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span><img src="https://latex.codecogs.com/png.latex?m"> is given
</div>
</div>
<div class="callout-body-container callout-body">
<p>Only the case where all the multinomial components share the same parameter <img src="https://latex.codecogs.com/png.latex?m"> (the number of trials) is considered. The reason is that optimising for an integer number <img src="https://latex.codecogs.com/png.latex?m"> is beyond the scope of this post.</p>
</div>
</div>
<section id="data-generation-1" class="level4" data-number="3.2.1">
<h4 data-number="3.2.1" class="anchored" data-anchor-id="data-generation-1"><span class="header-section-number">3.2.1</span> Data generation</h4>
<p>A data-point of the multinomial mixture model can be generated as follows:</p>
<ul>
<li>sample a probability <img src="https://latex.codecogs.com/png.latex?%5Cpi"> from a Dirichlet prior: <img src="https://latex.codecogs.com/png.latex?%5Cpi%20%5Csim%20%5CPr(%5Cpi%20%7C%20%5Calpha)%20=%20%5Coperatorname%7BDir%7D(%5Cpi%20%7C%20%5Calpha)">,</li>
<li>sample <img src="https://latex.codecogs.com/png.latex?K"> probability vectors, <img src="https://latex.codecogs.com/png.latex?%5C%7B%20%5Crho_%7Bk%7D%20%5C%7D_%7Bk%20=%201%7D%5E%7BK%7D">, from a Dirichlet prior: <img src="https://latex.codecogs.com/png.latex?%5Crho_%7Bk%7D%20%5Csim%20%5CPr(%5Crho%20%7C%20%5Cbeta%20)%20=%20%5Coperatorname%7BDir%7D(%5Crho%20%7C%20%5Cbeta)">,</li>
<li>sample the index of a multinomial component: <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bz%7D%20%5Csim%20%5CPr(%5Cmathbf%7Bz%7D%20%7C%20%5Cpi)%20=%20%5Coperatorname%7BCategorical%7D(%5Cmathbf%7Bz%7D%20%7C%20%5Cpmb%7B%5Cpi%7D)">, then</li>
<li>sample a data-point from the corresponding multinomial component: <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D%20%5Csim%20%5CPr(%5Cmathbf%7Bx%7D%20%7C%20%5Cmathbf%7Bz%7D,%20m,%20%5Crho)%20=%20%5Coperatorname%7BMultinomial%7D(%5Cmathbf%7Bx%7D%20%7C%20m,%20%5Crho_%7Bk%7D)">, where <img src="https://latex.codecogs.com/png.latex?z_%7Bk%7D%20=%201">.</li>
</ul>
<p>The data generation process can also be visualised in the graphical model shown below.</p>
<div class="cell" data-layout-align="default">
<div class="cell-output-display">
<div>
<p></p><figure class="figure"><p></p>
<div>
<pre class="mermaid mermaid-js">%%{
    init: {
        'theme': 'base',
        'themeVariables': {
            'primaryColor': '#ffffff'
        }
    }
}%%
flowchart LR
    subgraph data["data"]
        direction LR
        z((z)):::rv --&gt; x((x)):::rv
    end
    alpha((α)):::notfilled --&gt; pi((π)):::params --&gt; z;
    beta((β)):::params --&gt; rho((ρ)):::params;
    rho --&gt; x;

    style z fill: none
    classDef params stroke: #000, fill: none
    classDef rv stroke: #000
    classDef notfilled fill: none
    linkStyle default stroke: #000
    style data fill: none
</pre>
</div>
<p></p></figure><p></p>
</div>
</div>
</div>
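<p>The generative process above can be sketched with the Python standard library alone. The helper names, the seed and the sizes <code>K</code>, <code>D</code> and <code>m</code> below are hypothetical choices for illustration. The Dirichlet draw uses the standard construction of normalising independent Gamma variates.</p>

```python
import random


def sample_dirichlet(concentration, rng):
    """Draw a probability vector from a Dirichlet by normalising Gamma draws."""
    draws = [rng.gammavariate(a, 1.0) for a in concentration]
    total = sum(draws)
    return [g / total for g in draws]


def sample_data_point(pi, rho, num_trials, rng):
    """Sample a component index z, then counts x from Multinomial(num_trials, rho[z])."""
    z = rng.choices(range(len(pi)), weights=pi, k=1)[0]
    num_categories = len(rho[z])
    counts = [0] * num_categories
    # A multinomial draw is the histogram of num_trials categorical draws.
    for d in rng.choices(range(num_categories), weights=rho[z], k=num_trials):
        counts[d] += 1
    return z, counts


rng = random.Random(0)
K, D, m = 3, 4, 10  # components, categories, trials (illustrative values)
pi = sample_dirichlet([1.0] * K, rng)
rho = [sample_dirichlet([1.0] * D, rng) for _ in range(K)]
data = [sample_data_point(pi, rho, m, rng) for _ in range(5)]
```

<p>Each element of <code>data</code> pairs the latent index with a count vector whose entries sum to the number of trials.</p>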
</section>
<section id="objective-1" class="level4" data-number="3.2.2">
<h4 data-number="3.2.2" class="anchored" data-anchor-id="objective-1"><span class="header-section-number">3.2.2</span> Objective</h4>
<p>Given a set of data-points <img src="https://latex.codecogs.com/png.latex?%5C%7B%5Cmathbf%7Bx%7D_%7Bi%7D%5C%7D_%7Bi%20=%201%7D%5E%7BN%7D"> sampled from a multinomial mixture distribution, the aim is to infer a point estimate, in particular the MAP estimate, of <img src="https://latex.codecogs.com/png.latex?(%5Cpi,%20%5Crho)"> as follows:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmax_%7B%5Cpi,%20%5Crho%7D%20%5Cln%20%5CPr(%5Cpi,%20%5Crho%20%7C%20%5C%7B%5Cmathbf%7Bx%7D_%7Bi%7D%5C%7D_%7Bi%20=%201%7D%5E%7BN%7D,%20%5Calpha,%20m,%20%5Cbeta)%20=%20%5Cmax_%7B%5Cpi,%20%5Crho%7D%20%5Cfrac%7B1%7D%7BN%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D_%7Bi%7D%20%7C%20%5Cpi,%20m,%20%5Crho)%20+%20%5Cln%20%5Coperatorname%7BDir%7D(%5Cpi%20%7C%20%5Calpha)%20+%20%5Cln%20%5Coperatorname%7BDir%7D(%5Crho%20%7C%20%5Cbeta).%0A"></p>
</section>
<section id="parameter-inference-with-em" class="level4" data-number="3.2.3">
<h4 data-number="3.2.3" class="anchored" data-anchor-id="parameter-inference-with-em"><span class="header-section-number">3.2.3</span> Parameter inference with EM</h4>
<p><strong>E-step</strong> calculates the posterior of the latent variable <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bz%7D_%7Bi%7D"> given the data <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D_%7Bi%7D">: <span id="eq-mmm_e_step"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20&amp;%20=%20%5CPr(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201%20%7C%20%5Cmathbf%7Bx%7D_%7Bi%7D,%20%5Cpi%5E%7B(t)%7D,%20%5Crho%5E%7B(t)%7D)%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cfrac%7B%5CPr(%5Cmathbf%7Bx%7D_%7Bi%7D%20%7C%20%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201,%20%5Crho%5E%7B(t)%7D)%20%5C,%20%5CPr(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201%20%7C%20%5Cpi%5E%7B(t)%7D)%7D%7B%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20%5CPr(%5Cmathbf%7Bx%7D_%7Bi%7D%20%7C%20%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201,%20%5Crho%5E%7B(t)%7D)%20%5C,%20%5CPr(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201%20%7C%20%5Cpi%5E%7B(t)%7D)%7D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cfrac%7B%5Cpi_%7Bk%7D%5E%7B(t)%7D%20%5C,%20%5Cmathrm%7BMult%7D(%5Cmathbf%7Bx%7D_%7Bi%7D;%20m,%20%5Crho_%7Bk%7D%5E%7B(t)%7D)%7D%7B%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20%5Cpi_%7Bk%7D%5E%7B(t)%7D%20%5C,%20%5Cmathrm%7BMult%7D(%5Cmathbf%7Bx%7D_%7Bi%7D;%20m,%20%5Crho_%7Bk%7D%5E%7B(t)%7D)%7D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A%5Ctag%7B10%7D"></span></p>
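<p>A minimal sketch of the E-step in Equation&nbsp;10 is given below, computed in log-space for numerical stability. The function name <code>e_step</code> and the toy inputs are assumptions for illustration, and every entry of <code>rho</code> is assumed to be strictly positive. The multinomial coefficient is omitted because it depends only on the data-point and therefore cancels between the numerator and the denominator.</p>

```python
import math


def e_step(data, pi, rho):
    """Return resp[i][k] = q*(z_ik = 1) for the count vectors in data."""
    resp = []
    for x in data:
        # Unnormalised log-posterior of each component for this data-point.
        log_w = [
            math.log(pi[k]) + sum(x_d * math.log(r_d) for x_d, r_d in zip(x, rho[k]))
            for k in range(len(pi))
        ]
        # Subtract the maximum before exponentiating to avoid underflow.
        shift = max(log_w)
        weights = [math.exp(v - shift) for v in log_w]
        total = sum(weights)
        resp.append([w / total for w in weights])
    return resp


# Hypothetical toy problem with K = 2 components and D = 3 categories.
data = [[8, 1, 1], [1, 1, 8]]
pi = [0.5, 0.5]
rho = [[0.8, 0.1, 0.1], [0.1, 0.1, 0.8]]
resp = e_step(data, pi, rho)
```

<p>On this toy input, the first data-point is assigned almost entirely to the first component and the second data-point to the second one.</p>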
<p><strong>M-step</strong> maximises the following expected complete-data log-likelihood, plus the log-priors, w.r.t. <img src="https://latex.codecogs.com/png.latex?%5Cpi"> and <img src="https://latex.codecogs.com/png.latex?%5Crho">:</p>
<p><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20&amp;%20%5Coperatorname*%7Bargmax%7D_%7B%5Cpi,%20%5Crho%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Cmathbb%7BE%7D_%7Bq%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bi%7D)%7D%20%5B%20%5Cln%20%5CPr(%5Cmathbf%7Bx%7D_%7Bi%7D%20%7C%20%5Cmathbf%7Bz%7D_%7Bi%7D,%20m,%20%5Crho)%20+%20%5Cln%20%5CPr(%5Cmathbf%7Bz%7D_%7Bi%7D%20%7C%20%5Cpi)%20%5D%20+%20%5Cln%20%5CPr(%5Cpi%20%7C%20%5Calpha)%20+%20%5Cln%20%5CPr(%5Crho%20%7C%20%5Cbeta)%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Coperatorname*%7Bargmax%7D_%7B%5Cpi,%20%5Crho%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Cmathbb%7BE%7D_%7Bq%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bi%7D)%7D%20%5Cleft%5B%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20%5Cmathbf%7Bz%7D_%7Bik%7D%20%5Cln%20%5Coperatorname%7BMult%7D(%5Cmathbf%7Bx%7D_%7Bi%7D%20%7C%20m,%20%5Crho_%7Bk%7D)%20+%20%5Cln%20%5Coperatorname%7BCategorical%7D(%5Cmathbf%7Bz%7D_%7Bi%7D%20%7C%20%5Cpi)%20%5Cright%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20+%20%5Cln%20%5Coperatorname%7BDir%7D(%5Cpi%20%7C%20%5Calpha)%20+%20%5Cln%20%5Coperatorname%7BDir%7D(%5Crho%20%7C%20%5Cbeta)%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Coperatorname*%7Bargmax%7D_%7B%5Cpi,%20%5Crho%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Cmathbb%7BE%7D_%7Bq%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bi%7D)%7D%20%5Cleft%5B%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20%5Cmathbf%7Bz%7D_%7Bik%7D%20%5Cleft(%20%5Csum_%7Bd%20=%201%7D%5E%7BD%7D%20%5Cmathbf%7Bx%7D_%7Bid%7D%20%5Cln%20%5Crho_%7Bkd%7D%20%5Cright)%20+%20%5Cmathbf%7Bz%7D_%7Bik%7D%20%5Cln%20%5Cpi_%7Bk%7D%20%5Cright%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20+%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20%5Cleft%5B%20(%5Calpha%20-%201)%20%5Cln%20%5Cpi_%7Bk%7D%20+%20(%5Cbeta%20-%201)%20%5Csum_%7Bd%20=%201%7D%5E%7BD%7D%20%5Cln%20%5Crho_%7Bkd%7D%20%5Cright%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Coperatorname*%7Bargmax%7D_%7B%5Cpi,%20%5Crho%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5Cleft%5B%20%5Cln%20%5Cpi_%7Bk%7D%20+%20%5Csum_%7Bd%20=%201%7D%5E%7BD%7D%20%5Cmathbf%7Bx%7D_%7Bid%7D%20%5Cln%20%5Crho_%7Bkd%7D%20%5Cright%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20+%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20%5Cleft%5B%20(%5Calpha%20-%201)%20%5Cln%20%5Cpi_%7Bk%7D%20+%20(%5Cbeta%20-%201)%20%5Csum_%7Bd%20=%201%7D%5E%7BD%7D%20%5Cln%20%5Crho_%7Bkd%7D%20%5Cright%5D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"></p>
<div class="callout callout-style-default callout-note callout-titled" title="Probability constraints on $\pi$ and $\rho$">
<div class="callout-header d-flex align-content-center">
<div class="callout-icon-container">
<i class="callout-icon"></i>
</div>
<div class="callout-title-container flex-fill">
<span class="screen-reader-only">Note</span>Probability constraints on <img src="https://latex.codecogs.com/png.latex?%5Cpi"> and <img src="https://latex.codecogs.com/png.latex?%5Crho">
</div>
</div>
<div class="callout-body-container callout-body">
<p>Due to the nature of a multinomial mixture model, both the parameters <img src="https://latex.codecogs.com/png.latex?%5Cpi"> and <img src="https://latex.codecogs.com/png.latex?%5Crho"> are probability vectors.</p>
</div>
</div>
<p>The Lagrangian for <img src="https://latex.codecogs.com/png.latex?%5Cpi"> can be written as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathsf%7BL%7D_%7B%5Cpi%7D%20=%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5Cln%20%5Cpi_%7Bk%7D%20+%20(%5Calpha%20-%201)%20%5Cln%20%5Cpi_%7Bk%7D%20-%20%5Clambda%20%5Cleft(%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20%5Cpi_%7Bk%7D%20-%201%20%5Cright),%0A"> where <img src="https://latex.codecogs.com/png.latex?%5Clambda"> is the Lagrange multiplier.</p>
<p>Taking the derivative of the Lagrangian w.r.t. <img src="https://latex.codecogs.com/png.latex?%5Cpi_%7Bk%7D"> gives: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cfrac%7B%5Cpartial%20%5Cmathsf%7BL%7D_%7B%5Cpi%7D%7D%7B%5Cpartial%20%5Cpi_%7Bk%7D%7D%20=%20%5Cfrac%7B1%7D%7B%5Cpi_%7Bk%7D%7D%20%5Cleft%5B%20%5Calpha%20-%201%20+%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5Cright%5D%20-%20%5Clambda.%0A"></p>
<p>Setting the derivative to zero and solving for <img src="https://latex.codecogs.com/png.latex?%5Cpi_%7Bk%7D"> gives: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cpi_%7Bk%7D%20=%20%5Cfrac%7B1%7D%7B%5Clambda%7D%20%5Cleft%5B%20%5Calpha%20-%201%20+%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5Cright%5D.%0A"></p>
<p>And since <img src="https://latex.codecogs.com/png.latex?%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20%5Cpi_%7Bk%7D%20=%201">, one can substitute and find that <img src="https://latex.codecogs.com/png.latex?%5Clambda%20=%20N%20+%20K%20(%5Calpha%20-%201)">. Thus: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cboxed%7B%0A%20%20%20%20%20%20%20%20%5Cpi_%7Bk%7D%5E%7B(t%20+%201)%7D%20=%20%5Cfrac%7B%5Calpha%20-%201%20+%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%7D%7BN%20+%20K%20(%5Calpha%20-%201)%7D.%0A%20%20%20%20%7D%0A"></p>
<p>Similarly, the Lagrangian of <img src="https://latex.codecogs.com/png.latex?%5Crho"> can be expressed as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathsf%7BL%7D_%7B%5Crho%7D%20=%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5Csum_%7Bd%20=%201%7D%5E%7BD%7D%20%5Cmathbf%7Bx%7D_%7Bid%7D%20%5Cln%20%5Crho_%7Bkd%7D%20+%20(%5Cbeta%20-%201)%20%5Cln%20%5Crho_%7Bkd%7D%20-%20%5Csum_%7Bk%20=%201%7D%5E%7BK%7D%20%5Ceta_%7Bk%7D%20%5Cleft(%20%5Csum_%7Bd%20=%201%7D%5E%7BD%7D%20%5Crho_%7Bkd%7D%20-%201%20%5Cright),%0A"> where <img src="https://latex.codecogs.com/png.latex?%5Ceta_%7Bk%7D"> is the Lagrange multiplier. Taking the derivative w.r.t. <img src="https://latex.codecogs.com/png.latex?%5Crho_%7Bkd%7D"> gives: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cfrac%7B%5Cpartial%20%5Cmathsf%7BL%7D_%7B%5Crho%7D%7D%7B%5Cpartial%20%5Crho_%7Bkd%7D%7D%20=%20%5Cfrac%7B1%7D%7B%5Crho_%7Bkd%7D%7D%20%5Cleft%5B%20%5Cbeta%20-%201%20+%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5Cmathbf%7Bx%7D_%7Bid%7D%20%5Cright%5D%20-%20%5Ceta_%7Bk%7D.%0A"> Setting the derivative to zero and solving for <img src="https://latex.codecogs.com/png.latex?%5Crho_%7Bkd%7D"> gives: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Crho_%7Bkd%7D%20=%20%5Cfrac%7B1%7D%7B%5Ceta_%7Bk%7D%7D%20%5Cleft%5B%20%5Cbeta%20-%201%20+%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5Cmathbf%7Bx%7D_%7Bid%7D%20%5Cright%5D.%0A"> The constraint on <img src="https://latex.codecogs.com/png.latex?%5Crho_%7Bk%7D"> as a <img src="https://latex.codecogs.com/png.latex?D">-dimensional probability vector, together with <img src="https://latex.codecogs.com/png.latex?%5Csum_%7Bd%20=%201%7D%5E%7BD%7D%20%5Cmathbf%7Bx%7D_%7Bid%7D%20=%20m">, leads to <img src="https://latex.codecogs.com/png.latex?%5Ceta_%7Bk%7D%20=%20D%20(%5Cbeta%20-%201)%20+%20m%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)">. 
Thus: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cboxed%7B%0A%20%20%20%20%20%20%20%20%5Crho_%7Bkd%7D%5E%7B(t%20+%201)%7D%20=%20%5Cfrac%7B%5Cbeta%20-%201%20+%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%20%5Cmathbf%7Bx%7D_%7Bid%7D%7D%7BD%20(%5Cbeta%20-%201)%20+%20m%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20q%5E%7B*%7D(%5Cmathbf%7Bz%7D_%7Bik%7D%20=%201)%7D.%0A%20%20%20%20%7D%0A"></p>
<p>One can also refer to <span class="citation" data-cites="elmore2003identifiability">(Elmore and Wang 2003)</span> for a similar derivation and result.</p>
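<p>The two M-step updates can be sketched as follows, assuming symmetric Dirichlet priors with scalar concentrations <code>alpha</code> and <code>beta</code> that are at least one. The function name <code>m_step</code> and the toy inputs are hypothetical. Instead of writing the multiplier for each component in closed form, the sketch obtains it by summing the numerators over the categories, which is exactly what the sum-to-one constraint requires.</p>

```python
def m_step(data, resp, alpha, beta):
    """MAP updates of pi and rho under symmetric Dirichlet priors."""
    num_points = len(data)
    num_components = len(resp[0])
    num_categories = len(data[0])

    # Update of pi: (alpha - 1 + sum_i q_ik) / (N + K (alpha - 1)).
    counts = [sum(row[k] for row in resp) for k in range(num_components)]
    denominator = num_points + num_components * (alpha - 1.0)
    pi = [(alpha - 1.0 + counts[k]) / denominator for k in range(num_components)]

    rho = []
    for k in range(num_components):
        numerators = [
            beta - 1.0 + sum(resp[i][k] * data[i][d] for i in range(num_points))
            for d in range(num_categories)
        ]
        # The multiplier is the value that makes rho_k sum to one.
        eta_k = sum(numerators)
        rho.append([v / eta_k for v in numerators])
    return pi, rho


# Hypothetical hard responsibilities for N = 3 data-points and K = 2 components.
data = [[8, 1, 1], [1, 1, 8], [7, 2, 1]]
resp = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
pi, rho = m_step(data, resp, alpha=2.0, beta=2.0)
```

<p>Alternating this update with the E-step until the parameters stop changing gives the full EM loop.</p>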
</section>
</section>
</section>
<section id="references" class="level2" data-number="4">
<h2 data-number="4" class="anchored" data-anchor-id="references"><span class="header-section-number">4</span> References</h2>
<div id="refs" class="references csl-bib-body hanging-indent">
<div id="ref-bernardo2003variational" class="csl-entry">
<span class="nocase">Bernardo, JM, MJ Bayarri, JO Berger, et al.</span> 2003. <span>“The Variational Bayesian EM Algorithm for Incomplete Data: With Application to Scoring Graphical Model Structures.”</span> <em>Bayesian Statistics</em> 7 (453-464): 210.
</div>
<div id="ref-bishop2006pattern" class="csl-entry">
Bishop, Christopher M. 2006. <em>Pattern Recognition and Machine Learning</em>. Vol. 4. Springer.
</div>
<div id="ref-dempster1977maximum" class="csl-entry">
Dempster, Arthur P, Nan M Laird, and Donald B Rubin. 1977. <span>“Maximum Likelihood from Incomplete Data via the EM Algorithm.”</span> <em>Journal of the Royal Statistical Society: Series B (Methodological)</em> 39 (1): 1–22.
</div>
<div id="ref-elmore2003identifiability" class="csl-entry">
Elmore, Ryan T, and Shaoli Wang. 2003. <em>Identifiability and Estimation in Finite Mixture Models with Multinomial Components</em>. Technical Report 03-04, Pennsylvania State University.
</div>
</div>


<!-- -->

</section>

<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-reuse"><h2 class="anchored quarto-appendix-heading">Reuse</h2><div class="quarto-appendix-contents"><div><a rel="license" href="https://creativecommons.org/licenses/by/4.0/">CC BY 4.0</a></div></div></section><section class="quarto-appendix-contents" id="quarto-citation"><h2 class="anchored quarto-appendix-heading">Citation</h2><div><div class="quarto-appendix-secondary-label">BibTeX citation:</div><pre class="sourceCode code-with-copy quarto-appendix-bibtex"><code class="sourceCode bibtex">@online{nguyen2022,
  author = {Nguyen, Cuong},
  title = {Expectation - {Maximisation} Algorithm and Its Applications
    in Finite Mixture Models},
  date = {2022-07-17},
  url = {https://cnguyen10.github.io/posts/mixture-models/},
  langid = {en}
}
</code></pre><div class="quarto-appendix-secondary-label">For attribution, please cite this work as:</div><div id="ref-nguyen2022" class="csl-entry quarto-appendix-citeas">
Nguyen, Cuong. 2022. <span>“Expectation - Maximisation Algorithm and Its
Applications in Finite Mixture Models.”</span> July 17. <a href="https://cnguyen10.github.io/posts/mixture-models/">https://cnguyen10.github.io/posts/mixture-models/</a>.
</div></div></section></div> ]]></description>
  <category>Expectation-Maximization</category>
  <category>Clustering</category>
  <category>Latent Variables</category>
  <guid>https://cnguyen10.github.io/posts/mixture-models/</guid>
  <pubDate>Sun, 17 Jul 2022 00:00:00 GMT</pubDate>
</item>
<item>
  <title>Bias - variance decomposition</title>
  <dc:creator>Cuong Nguyen</dc:creator>
  <link>https://cnguyen10.github.io/posts/bias-variance-decomposition/</link>
  <description><![CDATA[ 




<p>Bias - variance decomposition is one of the key tools for understanding machine learning. However, conventional discussions of the decomposition revolve around the square loss (also known as mean squared error). It is unclear whether such a decomposition remains valid for other common loss functions, such as the 0-1 loss or the cross-entropy loss used in classification. This post presents the decomposition for those losses following the <em>unified</em> framework of bias and variance decomposition from <span class="citation" data-cites="domingos2000unified">(Domingos 2000)</span>, its extension to <em>Bregman divergences</em> with <em>unbounded support</em> from <span class="citation" data-cites="pfau2025generalized">(Pfau 2025)</span>, and the special case of the Kullback-Leibler (KL) divergence <span class="citation" data-cites="heskes1998bias">(Heskes 1998)</span>.</p>
<section id="notations" class="level2" data-number="1">
<h2 data-number="1" class="anchored" data-anchor-id="notations"><span class="header-section-number">1</span> Notations</h2>
<p>The notations are similar to the ones in <span class="citation" data-cites="domingos2000unified">(Domingos 2000)</span>, but for <img src="https://latex.codecogs.com/png.latex?C">-class classification.</p>
<table class="table-striped table-hover caption-top table">
<caption>Notations used in the bias-variance decomposition.</caption>
<colgroup>
<col style="width: 25%">
<col style="width: 75%">
</colgroup>
<thead>
<tr class="header">
<th>Notation</th>
<th>Description</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td><img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D"></td>
<td>an input instance in <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BX%7D%20%5Csubseteq%20%5Cmathbb%7BR%7D%5E%7Bd%7D"></td>
</tr>
<tr class="even">
<td><img src="https://latex.codecogs.com/png.latex?%5CDelta_%7BK%7D"></td>
<td>the <img src="https://latex.codecogs.com/png.latex?K">-dimensional simplex <img src="https://latex.codecogs.com/png.latex?%5Cequiv%20%5C%7B%5Cmathbf%7Bv%7D%20%5Cin%20%5Cmathbb%7BR%7D%5E%7BK%20+%201%7D_%7B+%7D:%20%5Cmathbf%7Bv%7D%5E%7B%5Ctop%7D%20%5Cpmb%7B1%7D%20=%201%5C%7D"></td>
</tr>
<tr class="even">
<td><img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bt%7D"></td>
<td>a label instance: <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bt%7D%20%5Csim%20%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)">, for example: (i) one-hot vector if <img src="https://latex.codecogs.com/png.latex?%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)"> is a categorical distribution, or (ii) soft-label if <img src="https://latex.codecogs.com/png.latex?%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)"> is a Dirichlet or logistic normal distribution</td>
</tr>
<tr class="odd">
<td><img src="https://latex.codecogs.com/png.latex?%5Cell"></td>
<td>loss function <img src="https://latex.codecogs.com/png.latex?%5Cell:%20%5CDelta_%7BC%20-%201%7D%20%5Ctimes%20%5CDelta_%7BC%20-%201%7D%20%5Cto%20%5B0,%20+%5Cinfty%5D">, e.g.&nbsp;0-1 loss or cross-entropy loss</td>
</tr>
<tr class="even">
<td><img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D"></td>
<td>predicted label distribution: <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%20=%20f(%5Cmathbf%7Bx%7D)%20%5Cin%20%5CDelta_%7BC%20-%201%7D"></td>
</tr>
<tr class="odd">
<td><img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D"></td>
<td>the set of training sets</td>
</tr>
</tbody>
</table>
</section>
<section id="terminologies" class="level2" data-number="2">
<h2 data-number="2" class="anchored" data-anchor-id="terminologies"><span class="header-section-number">2</span> Terminologies</h2>
<div id="def-optimal-prediction" class="theorem definition">
<p><span class="theorem-title"><strong>Definition 1</strong></span> The optimal prediction <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7B*%7D%20%5Cin%20%5CDelta_%7BC%20-%201%7D"> of a target <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bt%7D"> is defined as follows: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbf%7By%7D_%7B*%7D%20=%20%5Coperatorname*%7Bargmin%7D_%7B%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5Cleft%5B%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%20%5Cright)%20%5Cright%5D.%0A"></p>
</div>
<div id="def-main-model-prediction" class="theorem definition">
<p><span class="theorem-title"><strong>Definition 2</strong></span> The main model prediction for a loss function, <img src="https://latex.codecogs.com/png.latex?%5Cell">, and the set of training sets, <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D">, is defined as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbf%7By%7D_%7Bm%7D%20=%20%5Coperatorname*%7Bargmin%7D_%7B%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%7D%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cleft%5B%20%5Cell%20%5Cleft(%5Cmathbf%7By%7D,%20%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%20%5Cright)%20%5Cright%5D.%0A"></p>
</div>
<div class="proof remark">
<p><span class="proof-title"><em>Remark</em>. </span>The definitions of <em>optimal</em> and <em>main model</em> predictions above assume that the loss function <img src="https://latex.codecogs.com/png.latex?%5Cell"> is symmetric in its input arguments. For an asymmetric loss function, such as a Bregman divergence or the cross-entropy, the order of the input arguments in these definitions may need to change.</p>
</div>
<p>Given the definitions of <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7B*%7D"> and <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bm%7D">, the bias, variance and noise can be defined following the <em>unified</em> framework proposed in <span class="citation" data-cites="domingos2000unified">(Domingos 2000)</span> as follows:</p>
<div id="def-bias" class="theorem definition">
<p><span class="theorem-title"><strong>Definition 3</strong></span> The bias of a learner on an example <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D"> is defined as: <img src="https://latex.codecogs.com/png.latex?B(%5Cmathbf%7Bx%7D)%20=%20%5Cell%20%5Cleft(%20%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D_%7Bm%7D%20%5Cright)">.</p>
</div>
<div id="def-variance" class="theorem definition">
<p><span class="theorem-title"><strong>Definition 4</strong></span> The variance of a learner on an example <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D"> is defined as: <img src="https://latex.codecogs.com/png.latex?V(%5Cmathbf%7Bx%7D)%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cleft%5B%20%5Cell%20%5Cleft(%20%5Cmathbf%7By%7D_%7Bm%7D,%20%5Cmathbf%7By%7D%20%5Cright)%20%5Cright%5D">.</p>
</div>
<div id="def-noise" class="theorem definition">
<p><span class="theorem-title"><strong>Definition 5</strong></span> The noise of an example <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D"> is defined as: <img src="https://latex.codecogs.com/png.latex?N(%5Cmathbf%7Bx%7D)%20=%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5Cleft%5B%20%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%20%5Cright%5D">.</p>
</div>
<p>The definitions of bias and variance above are quite intuitive compared with other definitions in the literature. As <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bm%7D"> is the <em>main</em> model prediction, the bias <img src="https://latex.codecogs.com/png.latex?B(%5Cmathbf%7Bx%7D)"> measures the systematic deviation (loss) from the <em>optimal</em> (or true) label <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7B*%7D">, while the variance <img src="https://latex.codecogs.com/png.latex?V(%5Cmathbf%7Bx%7D)"> measures the loss induced by the fluctuations of each model prediction <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D"> on different training datasets around the <em>main</em> prediction <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bm%7D">. In addition, as the loss <img src="https://latex.codecogs.com/png.latex?%5Cell"> is non-negative, both the bias and variance are also non-negative.</p>
<p>Given the definitions of bias, variance and noise above, the unified decomposition proposed in <span class="citation" data-cites="domingos2000unified">(Domingos 2000)</span> can be expressed as: <span id="eq-unified_decomposition"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%5D%20&amp;%20=%20%5Ctextcolor%7BCrimson%7D%7B%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D_%7Bm%7D)%7D%20+%20c_%7B1%7D%20%5C,%20%5Ctextcolor%7BMidnightBlue%7D%7B%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%5B%5Cell(%5Cmathbf%7By%7D_%7Bm%7D,%20%5Cmathbf%7By%7D)%5D%7D%20+%20c_%7B2%7D%20%5C,%20%5Ctextcolor%7BGreen%7D%7B%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%5B%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%5D%7D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Ctextcolor%7BCrimson%7D%7BB(%5Cmathbf%7Bx%7D)%7D%20+%20c_%7B1%7D%20%5C,%20%5Ctextcolor%7BMidnightBlue%7D%7BV(%5Cmathbf%7Bx%7D)%7D%20+%20c_%7B2%7D%20%5C,%20%5Ctextcolor%7BGreen%7D%7BN(%5Cmathbf%7Bx%7D)%7D,%0A%20%20%20%20%5Cend%7Baligned%7D%0A%5Ctag%7B1%7D"></span> where <img src="https://latex.codecogs.com/png.latex?c_%7B1%7D"> and <img src="https://latex.codecogs.com/png.latex?c_%7B2%7D"> are two scalars. For example, for the MSE, <img src="https://latex.codecogs.com/png.latex?c_%7B1%7D%20=%20c_%7B2%7D%20=%201">.</p>
<p>Of course, not all losses satisfy the decomposition in Equation&nbsp;1. However, as shown in <span class="citation" data-cites="domingos2000unified">(Domingos 2000 - Theorem 7)</span>, such a decomposition can be used to bound the expected loss as long as the loss is a metric. Nevertheless, in this post, we discuss the decomposition for some common loss functions, such as the 0-1 loss and Bregman divergences, which include the MSE and the Kullback-Leibler (KL) divergence.</p>
</section>
<section id="square-loss" class="level2" data-number="3">
<h2 data-number="3" class="anchored" data-anchor-id="square-loss"><span class="header-section-number">3</span> Square loss</h2>
<p>To warm up, we discuss a well-known bias-variance decomposition in the literature: the one for the MSE, or square loss. Here, we use vector notation instead of the scalar notation often seen in conventional analyses. A later section derives a general decomposition for Bregman divergences, of which the MSE is a particular case.</p>
<div id="thm-mse" class="theorem">
<p><span class="theorem-title"><strong>Theorem 1</strong></span> When the loss is the square loss: <img src="https://latex.codecogs.com/png.latex?%5Cell(%5Cmathbf%7By%7D_%7B1%7D,%20%5Cmathbf%7By%7D_%7B2%7D)%20=%20%7C%7C%20%5Cmathbf%7By%7D_%7B1%7D%20-%20%5Cmathbf%7By%7D_%7B2%7D%7C%7C_%7B2%7D%5E%7B2%7D">, then the expected loss on several training sets can be decomposed into: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%20&amp;%20=%20%5Ctextcolor%7BCrimson%7D%7B%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D_%7Bm%7D)%7D%20+%20%5Ctextcolor%7BMidnightBlue%7D%7B%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cell(%5Cmathbf%7By%7D_%7Bm%7D,%20%5Cmathbf%7By%7D)%5D%7D%20+%20%5Ctextcolor%7BGreen%7D%7B%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%20%5Cell(%20%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D%20)%5D%7D%20%5C%5C%0A%20%20%20%20%20%20%20%20%5Ctext%7Bor:%20%7D%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%7C%7C%20%5Cmathbf%7Bt%7D%20-%20%5Cmathbf%7By%7D%20%7C%7C_%7B2%7D%5E%7B2%7D%20&amp;%20=%20%5Cunderbrace%7B%5Ctextcolor%7BCrimson%7D%7B%7C%7C%20%5Cmathbf%7By%7D_%7B*%7D%20-%20%5Cmathbf%7By%7D_%7Bm%7D%20%7C%7C_%7B2%7D%5E%7B2%7D%7D%7D_%7B%5Ctext%7Bbias%7D%7D%20+%20%5Cunderbrace%7B%5Ctextcolor%7BMidnightBlue%7D%7B%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%7C%7C%20%5Cmathbf%7By%7D_%7Bm%7D%20-%20%5Cmathbf%7By%7D%20%7C%7C_%7B2%7D%5E%7B2%7D%7D%7D_%7B%5Ctext%7Bvariance%7D%7D%20+%20%5Cunderbrace%7B%5Ctextcolor%7BGreen%7D%7B%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%7C%7C%20%5Cmathbf%7Bt%7D%20-%20%5Cmathbf%7By%7D_%7B*%7D%20%7C%7C_%7B2%7D%5E%7B2%7D%7D%7D_%7B%5Ctext%7Bnoise%7D%7D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"
></p>
</div>
<details>
<summary>
Please refer to the detailed proof here
</summary>
<div class="proof">
<p><span class="proof-title"><em>Proof</em>. </span>Given the square loss, the <em>optimal</em> prediction can be determined as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20&amp;%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%7C%7C%20%5Cmathbf%7Bt%7D%20-%20%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%20%7C%7C_%7B2%7D%5E%7B2%7D%20%5Cge%20%7C%7C%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5Cleft%5B%20%5Cmathbf%7Bt%7D%20%5Cright%5D%20-%20%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%20%7C%7C_%7B2%7D%5E%7B2%7D%20%5Cge%200%20%5Cquad%20%5Ctext%7B(Jensen's%20inequality%20on%20L2-norm)%7D%5C%5C%0A%20%20%20%20%20%20%20%20%5Cimplies%20&amp;%20%5Cmathbf%7By%7D_%7B*%7D%20=%20%5Coperatorname*%7Bargmin%7D_%7B%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%7C%7C%20%5Cmathbf%7Bt%7D%20-%20%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%20%7C%7C_%7B2%7D%5E%7B2%7D%20=%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"> Similarly, the <em>main</em> model prediction can be obtained as: <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bm%7D%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cmathbf%7By%7D%5D">.</p>
<p>The expected loss can then be written as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20&amp;%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%7C%7C%20%5Cmathbf%7Bt%7D%20-%20%5Cmathbf%7By%7D%20%7C%7C_%7B2%7D%5E%7B2%7D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20(%5Cmathbf%7Bt%7D%20-%20%5Cmathbf%7By%7D)%5E%7B%5Ctop%7D%20(%5Cmathbf%7Bt%7D%20-%20%5Cmathbf%7By%7D)%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5Cleft(%20(%5Cmathbf%7Bt%7D%20-%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D)%20+%20(%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D%20-%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cmathbf%7By%7D%5D)%20+%20(%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cmathbf%7By%7D%5D%20-%20%5Cmathbf%7By%7D)%20%5Cright)%5E%7B%5Ctop%7D%20%5Cleft(%20(%5Cmathbf%7Bt%7D%20-%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D)%20%5Cright.%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20%5Cleft.%20+%20(%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D%20-%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cmathbf%7By%7D%5D)%20+%20(%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cmathbf%7By%7D%5D%20-%20%5Cmathbf%7By%7D)%20%5Cright)%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%7C%7C%20%5Cmathbf%7Bt%7D%20-%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D%20%7C%7C_%7B2%7D%5E%7B2%7D%20+%20%7C
%7C%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D%20-%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cmathbf%7By%7D%5D%20%7C%7C_%7B2%7D%5E%7B2%7D%20+%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%7C%7C%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cmathbf%7By%7D%5D%20-%20%5Cmathbf%7By%7D%20%7C%7C_%7B2%7D%5E%7B2%7D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Ctextcolor%7BGreen%7D%7B%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%7C%7C%20%5Cmathbf%7Bt%7D%20-%20%5Cmathbf%7By%7D_%7B*%7D%20%7C%7C_%7B2%7D%5E%7B2%7D%7D%20+%20%5Ctextcolor%7BCrimson%7D%7B%7C%7C%20%5Cmathbf%7By%7D_%7B*%7D%20-%20%5Cmathbf%7By%7D_%7Bm%7D%20%7C%7C_%7B2%7D%5E%7B2%7D%7D%20+%20%5Ctextcolor%7BMidnightBlue%7D%7B%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%7C%7C%20%5Cmathbf%7By%7D_%7Bm%7D%20-%20%5Cmathbf%7By%7D%20%7C%7C_%7B2%7D%5E%7B2%7D%7D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"></p>
</div>
</details>
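<p>The decomposition in Theorem 1 can be checked numerically. The following is a minimal sketch in plain Python, assuming scalar targets and predictions with small finite supports; the numbers and the helper <code>expectation</code> are made up for illustration and are not part of the original derivation. It also assumes that the target and the prediction are independent given the input, which is what makes the cross terms in the proof vanish.</p>

```python
# Hypothetical toy setup: scalar targets and predictions with finite supports.
t_values, t_probs = [0.0, 1.0, 3.0], [0.2, 0.5, 0.3]  # Pr(t | x)
y_values, y_probs = [0.5, 1.0, 2.0], [0.3, 0.3, 0.4]  # predictions over training sets


def expectation(values, probs, f=lambda v: v):
    """Expectation of f(v) under a finite distribution."""
    return sum(p * f(v) for v, p in zip(values, probs))


y_star = expectation(t_values, t_probs)  # optimal prediction E[t]
y_m = expectation(y_values, y_probs)     # main prediction E_D[y]

# Left-hand side: E_D E_t (t - y)^2, with t and y independent given x.
expected_loss = sum(
    pt * py * (t - y) ** 2
    for t, pt in zip(t_values, t_probs)
    for y, py in zip(y_values, y_probs)
)

# Right-hand side: the three terms of the decomposition.
bias = (y_star - y_m) ** 2
variance = expectation(y_values, y_probs, lambda y: (y_m - y) ** 2)
noise = expectation(t_values, t_probs, lambda t: (t - y_star) ** 2)

print(expected_loss, bias + variance + noise)
```

<p>The two printed numbers agree up to floating-point error, which is what the theorem predicts for this toy example.</p>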
</section>
<section id="loss" class="level2" data-number="4">
<h2 data-number="4" class="anchored" data-anchor-id="loss"><span class="header-section-number">4</span> 0-1 loss</h2>
<div id="def-0-1-loss" class="theorem definition">
<p><span class="theorem-title"><strong>Definition 6</strong></span> The 0-1 loss is defined as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cell(%5Cmathbf%7By%7D_%7B1%7D,%20%5Cmathbf%7By%7D_%7B2%7D)%20=%20%5CBbb%7B1%7D%20(%5Cmathbf%7By%7D_%7B1%7D,%20%5Cmathbf%7By%7D_%7B2%7D)%20=%20%5Cbegin%7Bcases%7D%0A%20%20%20%20%20%20%20%200%20&amp;%20%5Ctext%7Bif%20%7D%20%5Cmathbf%7By%7D_%7B1%7D%20=%20%5Cmathbf%7By%7D_%7B2%7D,%5C%5C%0A%20%20%20%20%20%20%20%201%20&amp;%20%5Ctext%7Bif%20%7D%20%5Cmathbf%7By%7D_%7B1%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B2%7D.%0A%20%20%20%20%5Cend%7Bcases%7D%0A"></p>
</div>
<section id="binary-classification" class="level3" data-number="4.1">
<h3 data-number="4.1" class="anchored" data-anchor-id="binary-classification"><span class="header-section-number">4.1</span> Binary classification</h3>
<div id="thm-binary-0-1-loss" class="theorem">
<p><span class="theorem-title"><strong>Theorem 2</strong></span> (<span class="citation" data-cites="domingos2000unified">(Domingos 2000 - Theorem 2)</span>) The expected 0-1 loss in a <strong>binary classification</strong> setting can be written as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5Cleft%5B%20%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%20%5Cright%5D%20=%20%5Ctextcolor%7BCrimson%7D%7B%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D_%7Bm%7D)%7D%20+%20%5Ctextcolor%7BBrown%7D%7Bc%7D%20%5C,%20%5Ctextcolor%7BMidnightBlue%7D%7B%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cleft%5B%20%5Cell(%20%5Cmathbf%7By%7D,%20%5Cmathbf%7By%7D_%7Bm%7D%20)%20%5Cright%5D%7D%20+%20%5Cleft%5B%202%20%5Coperatorname%7BPr%7D_%7B%5Cmathcal%7BD%7D%7D(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7By%7D_%7B*%7D)%20-%201%20%5Cright%5D%20%20%5Ctextcolor%7BGreen%7D%7B%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5Cleft%5B%20%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%20%5Cright%5D%7D,%0A"> where: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Ctextcolor%7BBrown%7D%7Bc%7D%20=%20%5Cbegin%7Bcases%7D%0A%20%20%20%20%20%20%20%20+1%20&amp;%20%5Ctext%7Bif%20%7D%20%5Cmathbf%7By%7D_%7Bm%7D%20=%20%5Cmathbf%7By%7D_%7B*%7D%5C%5C%0A%20%20%20%20%20%20%20%20-1%20&amp;%20%5Ctext%7Botherwise%7D.%0A%20%20%20%20%5Cend%7Bcases%7D%0A"></p>
</div>
<details>
<summary>
The proof is reproduced from <span class="citation" data-cites="domingos2000unified">(Domingos 2000 - Theorem 2)</span> for a self-contained discussion.
</summary>
<div class="proof">
<p><span class="proof-title"><em>Proof</em>. </span>To prove the theorem, we decompose <img src="https://latex.codecogs.com/png.latex?%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%5D"> and <img src="https://latex.codecogs.com/png.latex?%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D)%5D"> separately, then combine the two results to complete the proof.</p>
<p>First, we prove the following: <span id="eq-expected_01_wrt_t"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%5D%20=%20%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D)%20+%20c_%7B0%7D%20%5C,%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%5D,%0A%5Ctag%7B2%7D"></span> with <img src="https://latex.codecogs.com/png.latex?c_%7B0%7D%20=%201"> if <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%20=%20%5Cmathbf%7By%7D_%7B*%7D"> and <img src="https://latex.codecogs.com/png.latex?c_%7B0%7D%20=%20-1"> if <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D">.</p>
<p>If <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%20=%20%5Cmathbf%7By%7D_%7B*%7D">, then Equation&nbsp;2 is trivially true with <img src="https://latex.codecogs.com/png.latex?c_%7B0%7D%20=%201">. We next prove Equation&nbsp;2 when <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D">. Since there are only two classes, if <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D"> and <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bt%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D">, then <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%20=%20%5Cmathbf%7Bt%7D"> and vice versa. And since two events are equivalent, <img src="https://latex.codecogs.com/png.latex?%5CPr(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7Bt%7D)%20=%20%5CPr(%5Cmathbf%7Bt%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D)">. The expected 0-1 loss w.r.t. <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bt%7D"> can be written as: <img 
src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%5D%20&amp;%20=%20%5CPr(%5Cmathbf%7Bt%7D%20%5Cneq%20%5Cmathbf%7By%7D)%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%201%20-%20%5CPr(%5Cmathbf%7Bt%7D%20=%20%5Cmathbf%7By%7D)%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%201%20-%20%5CPr(%5Cmathbf%7Bt%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D)%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%201%20-%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%20%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%20%5D%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D)%20-%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%20%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%20%5D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"> The last step holds because the 0-1 loss between the optimal prediction and a different prediction equals 1. This proves Equation&nbsp;2.</p>
<p>Next, we show that: <span id="eq-expected_01_wrt_D"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D)%5D%20=%20%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D_%7Bm%7D)%20+%20%5Ctextcolor%7BBrown%7D%7Bc%7D%20%5C,%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cell(%5Cmathbf%7By%7D,%20%5Cmathbf%7By%7D_%7Bm%7D)%5D.%0A%5Ctag%7B3%7D"></span></p>
<p>If <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bm%7D%20=%20%5Cmathbf%7By%7D_%7B*%7D">, then Equation&nbsp;3 is trivially true with <img src="https://latex.codecogs.com/png.latex?%5Ctextcolor%7BBrown%7D%7Bc%7D%20=%201">. If <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bm%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D">, then <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bm%7D%20%5Cneq%20%5Cmathbf%7By%7D"> implies that <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%20=%20%5Cmathbf%7By%7D_%7B*%7D"> and vice versa. Thus, the expected 0-1 loss w.r.t. the training sets can be expressed as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D)%5D%20&amp;%20=%20p(%5Cmathbf%7By%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D)%20=%201%20-%20p(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7By%7D_%7B*%7D)%20=%201%20-%20p(%5Cmathbf%7By%7D_%7Bm%7D%20%5Cneq%20%5Cmathbf%7By%7D)%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%201%20-%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cell(%5Cmathbf%7By%7D_%7Bm%7D,%20%5Cmathbf%7By%7D)%5D%20=%20%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D_%7Bm%7D)%20-%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cell(%5Cmathbf%7By%7D_%7Bm%7D,%20%5Cmathbf%7By%7D)%5D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"></p>
<p>This proves Equation&nbsp;3.</p>
<p>Finally, we combine the results in Equation&nbsp;2 and Equation&nbsp;3 to prove the theorem. Taking the expectation w.r.t. <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D"> on both sides of Equation&nbsp;2 gives: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5Cleft%5B%20%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%20%5Cright%5D%20&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D)%5D%20+%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cleft%5B%20c_%7B0%7D%20%5C,%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%5D%20%5Cright%5D%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D)%5D%20+%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Bc_%7B0%7D%5D%20%5C,%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%5D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"> The second line holds because the noise term does not depend on the training set.</p>
<p>Since: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Bc_%7B0%7D%5D%20&amp;%20=%20p(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7By%7D_%7B*%7D)%20-%20p(%5Cmathbf%7By%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D)%20=%202%20p(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7By%7D_%7B*%7D)%20-%201,%0A%20%20%20%20%5Cend%7Baligned%7D%0A"> we obtain the result of the theorem by applying Equation&nbsp;3.</p>
</div>
</details>
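<p>As a sanity check of Theorem 2, the sketch below enumerates a two-class problem exactly. It assumes, following the usual convention for the 0-1 loss, that the optimal and main predictions are the most probable target and the most probable prediction, respectively; the function name <code>decompose_binary</code> and the probabilities passed to it are hypothetical.</p>

```python
def zero_one(a, b):
    """0-1 loss: 0 if the labels agree, 1 otherwise."""
    return 0 if a == b else 1


def decompose_binary(p_t1, p_y1):
    """Return (expected loss, decomposed loss) for labels in {0, 1}.

    p_t1 = Pr(t = 1 | x) and p_y1 = Pr_D(y = 1) are made-up inputs.
    """
    t_dist = {0: 1.0 - p_t1, 1: p_t1}
    y_dist = {0: 1.0 - p_y1, 1: p_y1}
    y_star = max(t_dist, key=t_dist.get)  # optimal prediction (mode of t)
    y_m = max(y_dist, key=y_dist.get)     # main prediction (mode of y)

    # Left-hand side: exact double expectation of the 0-1 loss.
    expected = sum(
        pt * py * zero_one(t, y)
        for t, pt in t_dist.items()
        for y, py in y_dist.items()
    )

    # Right-hand side: bias + c * variance + coefficient * noise.
    bias = zero_one(y_star, y_m)
    variance = sum(py * zero_one(y, y_m) for y, py in y_dist.items())
    noise = sum(pt * zero_one(t, y_star) for t, pt in t_dist.items())
    c = 1 if y_m == y_star else -1
    noise_coef = 2.0 * y_dist[y_star] - 1.0
    return expected, bias + c * variance + noise_coef * noise


lhs, rhs = decompose_binary(p_t1=0.7, p_y1=0.4)
print(lhs, rhs)
```

<p>In this example the main prediction disagrees with the optimal one, so the variance enters with a negative sign, which illustrates why variance can reduce the expected 0-1 loss of a biased learner.</p>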
</section>
<section id="multi-class-classification" class="level3" data-number="4.2">
<h3 data-number="4.2" class="anchored" data-anchor-id="multi-class-classification"><span class="header-section-number">4.2</span> Multi-class classification</h3>
<div id="thm-multiclass-0-1-loss" class="theorem">
<p><span class="theorem-title"><strong>Theorem 3</strong></span> The expected 0-1 loss in a <strong>multi-class classification</strong> setting can be decomposed into: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20&amp;%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5Cleft%5B%20%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%20%5Cright%5D%20=%20%5Ctextcolor%7BCrimson%7D%7B%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D_%7Bm%7D)%7D%20+%20%5Ctextcolor%7BBrown%7D%7Bc%7D%20%5C,%20%5Ctextcolor%7BMidnightBlue%7D%7B%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cleft%5B%20%5Cell(%5Cmathbf%7By%7D,%20%5Cmathbf%7By%7D_%7Bm%7D)%20%5Cright%5D%7D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20+%20%5B%20%5Coperatorname%7BPr%7D_%7B%5Cmathcal%7BD%7D%7D%20(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7By%7D_%7B*%7D)%20-%20%5Coperatorname%7BPr%7D_%7B%5Cmathcal%7BD%7D%7D%20(%5Cmathbf%7By%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D)%20%5Coperatorname%7BPr%7D_%7B%5Cmathbf%7Bt%7D%7D(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7By%7D_%7B*%7D%20%5Cneq%20%5Cmathbf%7Bt%7D)%20%5D%20%5Ctextcolor%7BGreen%7D%7B%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%20%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%20%5D%7D,%0A%20%20%20%20%5Cend%7Baligned%7D%0A"> where: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20c%20=%20%5Cbegin%7Bcases%7D%0A%20%20%20%20%20%20%20%20+1%20&amp;%20%5Ctext%7Bif%20%7D%20%5Cmathbf%7By%7D_%7Bm%7D%20=%20%5Cmathbf%7By%7D_%7B*%7D%5C%5C%0A%20%20%20%20%20%20%20%20-%20%5Coperatorname%7BPr%7D_%7B%5Cmathcal%7BD%7D%7D%20(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7By%7D_%7B*%7D%20%7C%20%5Cmathbf%7By%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7Bm%7D)%20&amp;%20%5Ctext%7Botherwise%7D.%0A%20%20%20%20%5Cend%7Bcases%7D%0A"></p>
</div>
<details>
<summary>
The proof is reproduced from <span class="citation" data-cites="domingos2000unified">(Domingos 2000 - Theorem 3)</span> for a self-contained discussion.
</summary>
<div class="proof">
<p><span class="proof-title"><em>Proof</em>. </span>The proof is similar to that of binary classification, where we decompose <img src="https://latex.codecogs.com/png.latex?%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%5D"> and <img src="https://latex.codecogs.com/png.latex?%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D)%5D">. The key difference is that <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D"> and <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bt%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D"> no longer imply that <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%20=%20%5Cmathbf%7Bt%7D">. Similarly, <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bm%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D"> and <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bm%7D%20%5Cneq%20%5Cmathbf%7By%7D"> no longer imply <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%20=%20%5Cmathbf%7By%7D_%7B*%7D">.</p>
<p>Now, we want to prove the following decomposition: <span id="eq-expected_01_wrt_t_multiclass"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%5D%20=%20%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D)%20+%20c_%7B0%7D%20%5C,%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%5D,%0A%5Ctag%7B4%7D"></span> where: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20c_%7B0%7D%20=%20%5Cbegin%7Bcases%7D%0A%20%20%20%20%20%20%20%20-p(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7By%7D_%7B*%7D%20%5Cneq%20%5Cmathbf%7Bt%7D)%20&amp;%20%5Ctext%7Bwhen%20%7D%20%5Cmathbf%7By%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D%5C%5C%0A%20%20%20%20%20%20%20%201%20&amp;%20%5Ctext%7Bwhen%20%7D%20%5Cmathbf%7By%7D%20=%20%5Cmathbf%7By%7D_%7B*%7D.%0A%20%20%20%20%5Cend%7Bcases%7D%0A"></p>
<p>When <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%20=%20%5Cmathbf%7By%7D_%7B*%7D">, Equation&nbsp;4 is trivially true with <img src="https://latex.codecogs.com/png.latex?c_%7B0%7D%20=%201">.</p>
<p>When <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D">, the following fact is true: <img src="https://latex.codecogs.com/png.latex?%5CPr(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7Bt%7D%7C%20%5Cmathbf%7By%7D_%7B*%7D%20=%20%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D)%20=%200">. To simplify the notation, the condition <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%20%5Cneq%20%5Cmathbf%7By%7D_%7B*%7D"> is omitted below. Applying the sum rule to the probability that the predicted label matches the target gives: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20%5CPr(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7Bt%7D)%20&amp;%20=%20%5Cunderbrace%7B%5CPr(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7By%7D_%7B*%7D%20=%20%5Cmathbf%7Bt%7D)%7D_%7B0%7D%20%5C,%20%5CPr(%5Cmathbf%7By%7D_%7B*%7D%20=%20%5Cmathbf%7Bt%7D)%20+%20%5CPr(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7By%7D_%7B*%7D%20%5Cneq%20%5Cmathbf%7Bt%7D)%20%5C,%20%5CPr(%5Cmathbf%7By%7D_%7B*%7D%20%5Cneq%20%5Cmathbf%7Bt%7D)%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5CPr(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7By%7D_%7B*%7D%20%5Cneq%20%5Cmathbf%7Bt%7D)%20%5C,%20%5CPr(%5Cmathbf%7By%7D_%7B*%7D%20%5Cneq%20%5Cmathbf%7Bt%7D).%0A%20%20%20%20%5Cend%7Baligned%7D%0A"></p>
<p>The expected loss w.r.t. <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bt%7D"> can be written as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%5D%20&amp;%20=%20%5CPr(%5Cmathbf%7By%7D%20%5Cneq%20%5Cmathbf%7Bt%7D)%20=%201%20-%20%5CPr(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7Bt%7D)%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%201%20%5Cunderbrace%7B-%20%5CPr(%5Cmathbf%7By%7D%20=%20%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7By%7D_%7B*%7D%20%5Cneq%20%5Cmathbf%7Bt%7D)%7D_%7Bc_%7B0%7D%7D%20%5C,%20%5CPr(%5Cmathbf%7By%7D_%7B*%7D%20%5Cneq%20%5Cmathbf%7Bt%7D)%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D)%20+%20c_%7B0%7D%20%5C,%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cell(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%5D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"> This proves Equation&nbsp;4.</p>
<p>Similarly, one can prove the decomposition for the expected loss w.r.t. <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D">: <span id="eq-expected_01_wrt_D_multiclass"><img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D)%5D%20=%20%5Cell(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D_%7Bm%7D)%20+%20%5Ctextcolor%7BBrown%7D%7Bc%7D%20%5C,%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cell(%5Cmathbf%7By%7D,%20%5Cmathbf%7By%7D_%7Bm%7D)%5D.%0A%5Ctag%7B5%7D"></span></p>
<p>Combining the results in Equation&nbsp;4 and Equation&nbsp;5 in the same manner as in the binary classification case completes the proof.</p>
</div>
</details>
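<p>The multi-class decomposition obtained by combining Equation&nbsp;4 and Equation&nbsp;5 can also be checked by exact enumeration. The sketch below uses a made-up three-class example and makes two assumptions: the optimal and main predictions are taken to be the modes of the respective distributions, and the coefficient of the noise term is computed as the expectation of <code>c0</code> over the predictions, since <code>c0</code> depends on the predicted label. The example is chosen so that neither the noise nor the variance is zero, which keeps the conditional probabilities well defined.</p>

```python
def zero_one(a, b):
    """0-1 loss: 0 if the labels agree, 1 otherwise."""
    return 0 if a == b else 1


# Hypothetical three-class example (values chosen for illustration only).
t_dist = {0: 0.5, 1: 0.3, 2: 0.2}  # Pr(t | x)
y_dist = {0: 0.2, 1: 0.5, 2: 0.3}  # Pr_D(y)

y_star = max(t_dist, key=t_dist.get)  # optimal prediction (mode of t)
y_m = max(y_dist, key=y_dist.get)     # main prediction (mode of y)

bias = zero_one(y_star, y_m)
variance = 1.0 - y_dist[y_m]  # E_D[l(y, y_m)] = Pr_D(y != y_m)
noise = 1.0 - t_dist[y_star]  # E_t[l(t, y_*)] = Pr(t != y_*)


def c0(y):
    """Coefficient of Equation 4 for a fixed prediction y."""
    if y == y_star:
        return 1.0
    return -t_dist[y] / noise  # -Pr(t = y | t != y_*)


# Coefficient c of Equation 5.
if y_m == y_star:
    c = 1.0
else:
    c = -y_dist[y_star] / variance  # -Pr_D(y = y_* | y != y_m)

# Expectation of c0 over training sets, i.e. the noise coefficient.
noise_coef = sum(py * c0(y) for y, py in y_dist.items())

expected = sum(
    pt * py * zero_one(t, y)
    for t, pt in t_dist.items()
    for y, py in y_dist.items()
)
decomposed = bias + c * variance + noise_coef * noise

print(expected, decomposed)
```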
</section>
</section>
<section id="bregman-divergence" class="level2" data-number="5">
<h2 data-number="5" class="anchored" data-anchor-id="bregman-divergence"><span class="header-section-number">5</span> Bregman divergence</h2>
<p>The derivation and discussion in this section are adapted from <span class="citation" data-cites="pfau2025generalized">(Pfau 2025)</span>, with some modifications to keep the notation consistent.</p>
<div id="def-bregman-divergence" class="theorem definition">
<p><span class="theorem-title"><strong>Definition 7</strong></span> If <img src="https://latex.codecogs.com/png.latex?F:%20%5Cmathcal%7BY%7D%20%5Cto%20%5Cmathbb%7BR%7D"> is a strictly convex differentiable function, then the Bregman divergence derived from <img src="https://latex.codecogs.com/png.latex?F"> is the function <img src="https://latex.codecogs.com/png.latex?D_%7BF%7D:%20%5Cmathcal%7BY%7D%20%5Ctimes%20%5Cmathcal%7BY%7D%20%5Cto%20%5Cmathbb%7BR%7D_%7B+%7D"> defined as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20D_%7BF%7D%20(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%20=%20F(%5Cmathbf%7Bt%7D)%20-%20F(%5Cmathbf%7By%7D)%20-%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbf%7By%7D)%20%5C,%20(%5Cmathbf%7Bt%7D%20-%20%5Cmathbf%7By%7D).%0A"></p>
</div>
<div class="proof remark">
<p><span class="proof-title"><em>Remark</em>. </span>In general, the Bregman divergence is not symmetric and does not satisfy the triangle inequality. Thus, it is not a metric.</p>
</div>
<p>Some examples of Bregman divergence:</p>
<ul>
<li>Squared Euclidean distance or square loss: <img src="https://latex.codecogs.com/png.latex?D_%7BF%7D(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%20=%20%7C%7C%20%5Cmathbf%7Bt%7D%20-%20%5Cmathbf%7By%7D%20%7C%7C_%7B2%7D%5E%7B2%7D"> which is derived from the convex function <img src="https://latex.codecogs.com/png.latex?F(%5Cmathbf%7By%7D)%20=%20%7C%7C%20%5Cmathbf%7By%7D%20%7C%7C_%7B2%7D%5E%7B2%7D"></li>
<li>The squared Mahalanobis distance: <img src="https://latex.codecogs.com/png.latex?%0A%20%20D_%7BF%7D(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%20=%20%5Cfrac%7B1%7D%7B2%7D%20(%5Cmathbf%7Bt%7D%20-%20%5Cmathbf%7By%7D)%5E%7B%5Ctop%7D%20%5Cmathbf%7BQ%7D%20(%5Cmathbf%7Bt%7D%20-%20%5Cmathbf%7By%7D)%0A"> which is generated from the convex function: <img src="https://latex.codecogs.com/png.latex?F(%5Cmathbf%7By%7D)%20=%20%5Cfrac%7B1%7D%7B2%7D%20%5Cmathbf%7By%7D%5E%7B%5Ctop%7D%20%5Cmathbf%7BQ%7D%20%5Cmathbf%7By%7D"></li>
<li>The KL divergence: <img src="https://latex.codecogs.com/png.latex?%0A%20%20D_%7BF%7D(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%20=%20%5Cmathrm%7BKL%7D%20%5B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%7C%7C%20%5Cmathbf%7By%7D%5D%20=%20%5Csum_%7Bc%20=%201%7D%5E%7BC%7D%20%5CPr(%5Cmathbf%7Bt%7D%20=%20%5Cmathrm%7Bone-hot%7D(c)%20%7C%20%5Cmathbf%7Bx%7D)%20%5Cln%20%5Cfrac%7B%5CPr(%5Cmathbf%7Bt%7D%20=%20%5Cmathrm%7Bone-hot%7D(c)%20%7C%20%5Cmathbf%7Bx%7D)%7D%7B%5Cmathbf%7By%7D_%7Bc%7D%7D%0A"> which is generated from the negative entropy: <img src="https://latex.codecogs.com/png.latex?%0A%20%20F(%5Cmathbf%7By%7D)%20=%20%5Csum_%7Bc%20=%201%7D%5E%7BC%7D%20%5Cmathbf%7By%7D_%7Bc%7D%20%5Cln%20%5Cmathbf%7By%7D_%7Bc%7D.%0A"></li>
</ul>
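<p>The definition translates directly into code. The following plain-Python sketch, with a hypothetical helper <code>bregman</code> that takes a generator function and its gradient, recovers the square loss and the KL divergence from their generators and shows the lack of symmetry on two made-up probability vectors.</p>

```python
import math


def bregman(F, grad_F, t, y):
    """D_F(t, y) = F(t) - F(y) - grad F(y)^T (t - y) for lists of floats."""
    g = grad_F(y)
    return F(t) - F(y) - sum(gi * (ti - yi) for gi, ti, yi in zip(g, t, y))


def sq_norm(y):
    return sum(v * v for v in y)


def sq_norm_grad(y):
    return [2.0 * v for v in y]


def neg_entropy(y):
    return sum(v * math.log(v) for v in y)


def neg_entropy_grad(y):
    return [math.log(v) + 1.0 for v in y]


# Two made-up probability vectors (strictly positive, each summing to one).
p = [0.7, 0.2, 0.1]
q = [0.5, 0.3, 0.2]

sq = bregman(sq_norm, sq_norm_grad, p, q)              # squared Euclidean distance
kl_pq = bregman(neg_entropy, neg_entropy_grad, p, q)   # KL(p || q)
kl_qp = bregman(neg_entropy, neg_entropy_grad, q, p)   # KL(q || p)
kl_direct = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

print(sq, kl_pq, kl_direct, kl_qp)
```

<p>Note that the negative-entropy generator yields exactly the KL divergence only when both arguments sum to one; for unnormalized vectors an extra term equal to the difference of the sums appears.</p>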
<section id="some-properties-of-bregman-divergence" class="level3" data-number="5.1">
<h3 data-number="5.1" class="anchored" data-anchor-id="some-properties-of-bregman-divergence"><span class="header-section-number">5.1</span> Some properties of Bregman divergence</h3>
<p>This sub-section presents some properties of the Bregman divergence that are used in the bias-variance decomposition. Note that the quantities <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D"> and <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bm%7D"> used in this section need not be label distributions; they can simply be the raw outputs of a model (without any normalization, e.g.&nbsp;no <em>softmax</em>). The case of label distributions is considered in the subsequent section, where the loss function is the KL divergence.</p>
<div id="lem-bregman-mean-prediction" class="theorem lemma">
<p><span class="theorem-title"><strong>Lemma 1</strong></span> (Part 1 of Lemma 0.1 in <span class="citation" data-cites="pfau2025generalized">(Pfau 2025)</span>) The <em>mean prediction</em> for Bregman divergence with <strong>un-bounded support</strong> has the following property: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbf%7By%7D_%7Bm%7D%20=%20%5Coperatorname*%7Bargmin%7D_%7B%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%7D%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5BD_%7BF%7D%20(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D,%20%5Cmathbf%7By%7D)%5D%20%5CLeftrightarrow%20%5Cnabla%20F(%5Cmathbf%7By%7D_%7Bm%7D)%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cnabla%20F(%5Cmathbf%7By%7D)%20%5D.%0A"></p>
</div>
<details>
<summary>
Detailed proof
</summary>
<div class="proof">
<p><span class="proof-title"><em>Proof</em>. </span>&nbsp;</p>
<section id="necessary" class="level4" data-number="5.1.1">
<h4 data-number="5.1.1" class="anchored" data-anchor-id="necessary"><span class="header-section-number">5.1.1</span> Necessary</h4>
<p>If <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bm%7D"> is a minimizer of <img src="https://latex.codecogs.com/png.latex?%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5BD_%7BF%7D%20(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D,%20%5Cmathbf%7By%7D)%5D"> w.r.t. <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%5E%7B%5Cprime%7D">, then the gradient of this objective must vanish at the minimizer: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20%5Cnabla_%7B%5Cmathbf%7By%7D_%7Bm%7D%7D%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5BD_%7BF%7D%20(%5Cmathbf%7By%7D_%7Bm%7D,%20%5Cmathbf%7By%7D)%5D%20&amp;%20=%20%5Cnabla_%7B%5Cmathbf%7By%7D_%7Bm%7D%7D%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20F(%5Cmathbf%7By%7D_%7Bm%7D)%20-%20F(%5Cmathbf%7By%7D)%20-%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbf%7By%7D)%20%5C,%20(%5Cmathbf%7By%7D_%7Bm%7D%20-%20%5Cmathbf%7By%7D)%20%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cnabla_%7B%5Cmathbf%7By%7D_%7Bm%7D%7D%20F(%5Cmathbf%7By%7D_%7Bm%7D)%20-%20%5Cnabla_%7B%5Cmathbf%7By%7D_%7Bm%7D%7D%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbf%7By%7D)%20%5C,%20%5Cmathbf%7By%7D_%7Bm%7D%5D%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cnabla_%7B%5Cmathbf%7By%7D_%7Bm%7D%7D%20F(%5Cmathbf%7By%7D_%7Bm%7D)%20-%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cnabla%20F(%5Cmathbf%7By%7D)%20%5D%20=%200.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"> <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cimplies%20%5Cnabla%20F(%5Cmathbf%7By%7D_%7Bm%7D)%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cnabla%20F(%5Cmathbf%7By%7D)%20%5D.%0A"></p>
</section>
<section id="sufficient" class="level4" data-number="5.1.2">
<h4 data-number="5.1.2" class="anchored" data-anchor-id="sufficient"><span class="header-section-number">5.1.2</span> Sufficient</h4>
<p>Reversing the steps of the necessary condition, one can show that <img src="https://latex.codecogs.com/png.latex?%5Cnabla_%7B%5Cmathbf%7By%7D_%7Bm%7D%7D%20F(%5Cmathbf%7By%7D_%7Bm%7D)%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cnabla%20F(%5Cmathbf%7By%7D)%20%5D"> implies that <img src="https://latex.codecogs.com/png.latex?%5Cnabla_%7B%5Cmathbf%7By%7D_%7Bm%7D%7D%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5BD_%7BF%7D%20(%5Cmathbf%7By%7D_%7Bm%7D,%20%5Cmathbf%7By%7D)%5D%20=%200"> (assuming that <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bm%7D"> is independent of <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D">). Since <img src="https://latex.codecogs.com/png.latex?D_%7BF%7D"> is strictly convex in its first argument (a property inherited from the strict convexity of <img src="https://latex.codecogs.com/png.latex?F">), this stationary point <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bm%7D"> is the unique minimizer of <img src="https://latex.codecogs.com/png.latex?%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5BD_%7BF%7D%20(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D,%20%5Cmathbf%7By%7D)%5D">.</p>
</section>
<section id="note" class="level4" data-number="5.1.3">
<h4 data-number="5.1.3" class="anchored" data-anchor-id="note"><span class="header-section-number">5.1.3</span> Note</h4>
<p>The lemma only holds for Bregman divergences with <b>un-bounded support</b>, e.g.&nbsp;when <img src="https://latex.codecogs.com/png.latex?F"> is the squared L2 norm that generates the square loss. Otherwise, the gradient of <img src="https://latex.codecogs.com/png.latex?%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5BD_%7BF%7D%20(%5Cmathbf%7By%7D_%7Bm%7D,%20%5Cmathbf%7By%7D)%5D"> w.r.t. the first argument need not be zero at the minimizer; instead, the gradient of the Lagrangian that incorporates the additional constraints vanishes. This will be presented in the subsequent section where the loss function is the KL divergence.</p>
</section>
</div>
</details>
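<p>To illustrate Lemma 1 with a generator other than the squared norm, the sketch below uses the scalar function <code>F(y) = exp(y)</code>, which is strictly convex on the whole real line. This choice of generator, the toy distribution of predictions, and the list of candidate points are assumptions made for illustration. Because the gradient of this generator is again the exponential, the lemma predicts that the mean prediction is the logarithm of the expected gradient rather than the arithmetic mean of the predictions.</p>

```python
import math

# Hypothetical distribution of scalar predictions over training sets.
y_values, y_probs = [-1.0, 0.0, 2.0], [0.2, 0.5, 0.3]


def F(y):
    """Strictly convex generator with unbounded support (assumed example)."""
    return math.exp(y)


def grad_F(y):
    return math.exp(y)


def expected_divergence(y_prime):
    """E_D[D_F(y', y)] for a candidate mean prediction y'."""
    return sum(
        p * (F(y_prime) - F(y) - grad_F(y) * (y_prime - y))
        for y, p in zip(y_values, y_probs)
    )


# Lemma 1: grad F(y_m) = E_D[grad F(y)], so y_m = log E_D[exp(y)] here.
mean_grad = sum(p * grad_F(y) for y, p in zip(y_values, y_probs))
y_m = math.log(mean_grad)
arithmetic_mean = sum(p * y for y, p in zip(y_values, y_probs))

# Compare the objective at y_m against a few other candidates.
candidates = [y_m - 0.5, y_m - 0.01, arithmetic_mean, y_m + 0.01, y_m + 0.5]
best = min(candidates + [y_m], key=expected_divergence)

print(y_m, arithmetic_mean, best)
```

<p>Among the listed candidates, the point given by the lemma attains the smallest expected divergence, and it differs from the arithmetic mean of the predictions.</p>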
<div id="lem-bregman-optimal-prediction" class="theorem lemma">
<p><span class="theorem-title"><strong>Lemma 2</strong></span> (Part 2 of Lemma 0.1 in <span class="citation" data-cites="pfau2025generalized">(Pfau 2025)</span>) The <em>optimal prediction</em> of Bregman divergence can be expressed as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbf%7By%7D_%7B*%7D%20=%20%5Coperatorname*%7Bargmin%7D_%7B%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5BD_%7BF%7D%20(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)%5D%20=%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D.%0A"></p>
</div>
<details>
<summary>
Detailed proof
</summary>
<div class="proof">
<p><span class="proof-title"><em>Proof</em>. </span>The proof is quite straight-forward. One can calculate the gradient and solve for the root of the gradient as follows: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20%5Cnabla_%7B%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5BD_%7BF%7D%20(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)%5D%20&amp;%20=%20%5Cnabla_%7B%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%20F(%5Cmathbf%7Bt%7D)%20-%20F(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)%20-%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)%20%5C,%20(%5Cmathbf%7Bt%7D%20-%20%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)%20%5D%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20-%20%5Cnabla_%7B%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%7D%20F(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)%20-%20%5Cnabla%5E%7B2%7D%20F(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)%20%5Ctimes%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D%20+%20%5Cnabla%5E%7B2%7D%20F(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)%20%5Ctimes%20%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%20+%20%5Cnabla_%7B%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%7D%20F(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cnabla%5E%7B2%7D%20F(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)%20(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%20-%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D)%20=%200%0A%20%20%20%20%5Cend%7Baligned%7D%0A"> And since <img src="https://latex.codecogs.com/png.latex?F(.)"> is strictly convex, its Hessian matrix <img src="https://latex.codecogs.com/png.latex?%5Cnabla%5E%7B2%7D%20F(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)"> is positive definite and invertible. 
Hence, one can conclude that: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%20=%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D.%0A"></p>
</div>
</details>
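<p>As an informal numerical check of Lemma&nbsp;2, the following minimal Python sketch assumes a scalar generator <img src="https://latex.codecogs.com/png.latex?F(y)%20=%20y%20%5Cln%20y%20-%20y"> on positive reals and a handful of made-up labels (both are illustrative assumptions, not part of the lemma). A brute-force grid search over candidate predictions should land on the arithmetic mean of the labels.</p>

```python
import math


def bregman(t, y):
    """Bregman divergence D_F(t, y) for the assumed generator F(y) = y*ln(y) - y."""
    return t * math.log(t / y) - t + y


def expected_divergence(labels, y):
    """Average of D_F(t, y) over equally likely labels t."""
    return sum(bregman(t, y) for t in labels) / len(labels)


# Hypothetical labels t drawn from Pr(t | x); any positive values would do.
labels = [0.5, 1.0, 2.0, 4.5]
mean_label = sum(labels) / len(labels)

# Brute-force search over a grid of candidate predictions y'.
grid = [0.01 * k for k in range(1, 1001)]
best = min(grid, key=lambda y: expected_divergence(labels, y))

print(mean_label, best)
```

<p>The grid search is only a sanity check; the lemma itself gives the minimiser in closed form for any strictly convex generator.</p>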
<div id="lem-expected-bregman-divergence" class="theorem lemma">
<p><span class="theorem-title"><strong>Lemma 3</strong></span> (Part 1 of Theorem 0.1 in <span class="citation" data-cites="pfau2025generalized">(Pfau 2025)</span>) The expected Bregman divergences w.r.t. the set of training sets <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D"> have the following exact decomposition: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20D_%7BF%7D%20(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D,%20%5Cmathbf%7By%7D)%5D%20=%20D_%7BF%7D(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D,%20%5Cmathbf%7By%7D_%7Bm%7D)%20+%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5BD_%7BF%7D(%5Cmathbf%7By%7D_%7Bm%7D,%20%5Cmathbf%7By%7D)%5D,%0A"> where: <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bm%7D%20=%20%5Coperatorname*%7Bargmin%7D_%7B%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%7D%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5BD_%7BF%7D(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D,%20%5Cmathbf%7By%7D)%5D"> is the <em>mean prediction</em> of the model of interest, and <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D%5E%7B%5Cprime%7D"> is a (random) prediction that is independent from <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D">.</p>
</div>
<details>
<summary>
Detailed proof
</summary>
<div class="proof">
<p><span class="proof-title"><em>Proof</em>. </span>The proof is quite straightforward: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20&amp;%20D_%7BF%7D(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D,%20%5Cmathbf%7By%7D_%7Bm%7D)%20+%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5BD_%7BF%7D(%5Cmathbf%7By%7D_%7Bm%7D,%20%5Cmathbf%7By%7D)%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20F(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)%20-%20F(%5Cmathbf%7By%7D_%7Bm%7D)%20-%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbf%7By%7D_%7Bm%7D)%20%5Ctimes%20(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%20-%20%5Cmathbf%7By%7D_%7Bm%7D)%20+%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5BF(%5Cmathbf%7By%7D_%7Bm%7D)%20-%20F(%5Cmathbf%7By%7D)%20-%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbf%7By%7D)%20%5Ctimes%20(%5Cmathbf%7By%7D_%7Bm%7D%20-%20%5Cmathbf%7By%7D)%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20F(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)%20-%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbf%7By%7D_%7Bm%7D)%20%5Ctimes%20(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%20-%20%5Cmathbf%7By%7D_%7Bm%7D)%20-%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20F(%5Cmathbf%7By%7D)%20+%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbf%7By%7D)%20%5Ctimes%20(%5Cmathbf%7By%7D_%7Bm%7D%20-%20%5Cmathbf%7By%7D)%5D%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20F(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)%20-%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbf%7By%7D)%20%5D%20%5Ctimes%20(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%20-%20%5Cmathbf%7By%7D_%7Bm%7D)%20-%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20F(%5Cmathbf%7By%7D)%20+%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbf%7By%7D)%20%5Ctimes%20(%5Cmathbf%7By%7D_%7Bm%7D%20-%20%5Cmathbf%7By%7D)%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20F(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)%20-%20F(%5Cmathbf%7By%7D)%20-%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbf%7By%7D)%20%5Ctimes%20(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%20-%20%5Cmathbf%7By%7D_%7Bm%7D%20+%20%5Cmathbf%7By%7D_%7Bm%7D%20-%20%5Cmathbf%7By%7D)%20%5D%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20D_%7BF%7D%20(%5Cmathbf%7By%7D%5E%7B%5Cprime%7D,%20%5Cmathbf%7By%7D)%5D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"> The third equality is due to Lemma&nbsp;1.</p>
</div>
</details>
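<p>The identity in Lemma&nbsp;3 can also be checked numerically. The sketch below is only an illustration under assumed inputs: it reuses the hypothetical scalar generator <img src="https://latex.codecogs.com/png.latex?F(y)%20=%20y%20%5Cln%20y%20-%20y">, for which <img src="https://latex.codecogs.com/png.latex?%5Cnabla%20F(y)%20=%20%5Cln%20y">, so the condition <img src="https://latex.codecogs.com/png.latex?%5Cnabla%20F(%5Cmathbf%7By%7D_%7Bm%7D)%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cnabla%20F(%5Cmathbf%7By%7D)%20%5D"> makes the mean prediction the geometric mean of the made-up predictions.</p>

```python
import math


def bregman(a, b):
    """D_F(a, b) for the assumed generator F(y) = y*ln(y) - y."""
    return a * math.log(a / b) - a + b


def mean(values):
    return sum(values) / len(values)


# Hypothetical predictions y, one per training set D (equally likely).
predictions = [0.8, 1.5, 2.4, 3.0]

# Lemma 1 gives grad F(y_m) = E_D[grad F(y)]; here grad F(y) = ln(y),
# so the mean prediction y_m is the geometric mean of the predictions.
y_m = math.exp(mean([math.log(y) for y in predictions]))

y_prime = 1.7  # arbitrary prediction that does not depend on D

lhs = mean([bregman(y_prime, y) for y in predictions])
rhs = bregman(y_prime, y_m) + mean([bregman(y_m, y) for y in predictions])

print(lhs, rhs)
```

<p>Both printed values should agree up to floating-point error, whichever positive value is chosen for <code>y_prime</code>.</p>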
<div id="lem-expected-bregman-divergence-t" class="theorem lemma">
<p><span class="theorem-title"><strong>Lemma 4</strong></span> (Part 2 of Theorem 0.1 in <span class="citation" data-cites="pfau2025generalized">(Pfau 2025)</span>) The expected Bregman divergences w.r.t. the underlying label distribution <img src="https://latex.codecogs.com/png.latex?%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)"> have the following exact decomposition: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%20D_%7BF%7D%20(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%5D%20=%20D_%7BF%7D(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D)%20+%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5BD_%7BF%7D(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%5D,%0A"> where <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7B*%7D%20=%20%5Coperatorname*%7Bargmin%7D_%7B%5Cmathbf%7By%7D%5E%7B%5Cprime%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5BD_%7BF%7D%20(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D%5E%7B%5Cprime%7D)%5D%20=%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D"> is the <em>optimal prediction</em> in Lemma&nbsp;2.</p>
</div>
<details>
<summary>
Detailed proof
</summary>
<div class="proof">
<p><span class="proof-title"><em>Proof</em>. </span>The proof is quite straight-forward: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20&amp;%20D_%7BF%7D(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D)%20+%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5BD_%7BF%7D(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20F(%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D)%20-%20F(%5Cmathbf%7By%7D)%20-%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbf%7By%7D)%20%5Ctimes%20(%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D%20-%20%5Cmathbf%7By%7D)%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20+%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5BF(%5Cmathbf%7Bt%7D)%20-%20F(%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D)%20-%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D)%20%5Ctimes%20(%5Cmathbf%7Bt%7D%20-%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D)%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20-%20F(%5Cmathbf%7By%7D)%20-%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbf%7By%7D)%20%5Ctimes%20(%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D%20-%20%5Cmathbf%7By%7D)%20+%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5BF(%5Cmathbf%7Bt%7D)%20-%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D)%20%5Ctimes%20(%5Cmathbf%7Bt%7D%20-%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D)%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20-%20F(%5Cmathbf%7By%7D)%20-
%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbf%7By%7D)%20%5Ctimes%20(%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%5Cmathbf%7Bt%7D%5D%20-%20%5Cmathbf%7By%7D)%20+%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5BF(%5Cmathbf%7Bt%7D)%5D%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%20F(%5Cmathbf%7Bt%7D)%20-%20F(%5Cmathbf%7By%7D)%20-%20%5Cnabla%5E%7B%5Ctop%7D%20F(%5Cmathbf%7By%7D)%20%5Ctimes%20(%5Cmathbf%7Bt%7D%20-%20%5Cmathbf%7By%7D)%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%20D_%7BF%7D%20(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%5D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"></p>
</div>
</details>
</section>
<section id="decomposition-for-bregman-divergence" class="level3" data-number="5.2">
<h3 data-number="5.2" class="anchored" data-anchor-id="decomposition-for-bregman-divergence"><span class="header-section-number">5.2</span> Decomposition for Bregman divergence</h3>
<p>The main bias-variance decomposition result is stated in the following theorem:</p>
<div id="thm-decomposition-bregma-div" class="theorem">
<p><span class="theorem-title"><strong>Theorem 4</strong></span> The expected Bregman divergence over a set of training sets <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D"> can be decomposed into bias, variance and intrinsic noise terms: <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5BD_%7BF%7D%20(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%5D%20=%20%5Ctextcolor%7BCrimson%7D%7BD_%7BF%7D%20(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D_%7Bm%7D)%7D%20+%20%5Ctextcolor%7BMidnightBlue%7D%7B%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5BD_%7BF%7D(%5Cmathbf%7By%7D_%7Bm%7D,%20%5Cmathbf%7By%7D)%5D%7D%20+%20%5Ctextcolor%7BGreen%7D%7B%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5Cleft%5B%20D_%7BF%7D%20(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%20%5Cright%5D%7D.%0A"></p>
</div>
<details>
<summary>
Detailed proof
</summary>
<div class="proof">
<p><span class="proof-title"><em>Proof</em>. </span>The proof is a consequence of the previous lemma: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5BD_%7BF%7D%20(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D)%5D%20&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5BD_%7BF%7D(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D)%20+%20%5Ctextcolor%7BGreen%7D%7B%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5BD_%7BF%7D(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%5D%7D%20%5D%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%5B%20D_%7BF%7D(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D)%5D%20+%20%5Ctextcolor%7BGreen%7D%7B%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5BD_%7BF%7D(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%5D%7D%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20=%20%5Ctextcolor%7BCrimson%7D%7BD_%7BF%7D%20(%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D_%7Bm%7D)%7D%20+%20%5Ctextcolor%7BMidnightBlue%7D%7B%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5BD_%7BF%7D(%5Cmathbf%7By%7D_%7Bm%7D,%20%5Cmathbf%7By%7D)%5D%7D%20+%20%5Ctextcolor%7BGreen%7D%7B%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5BD_%7BF%7D(%5Cmathbf%7Bt%7D,%20%5Cmathbf%7By%7D_%7B*%7D)%5D%7D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"> The first equality is due to Lemma&nbsp;4 and the last equality of the above equation is due to Lemma&nbsp;3.</p>
</div>
</details>
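<p>To make Theorem&nbsp;4 concrete, the sketch below evaluates the three terms for two assumed scalar generators: <img src="https://latex.codecogs.com/png.latex?F(y)%20=%20y%5E%7B2%7D"> (square loss) and <img src="https://latex.codecogs.com/png.latex?F(y)%20=%20y%20%5Cln%20y%20-%20y">. The labels, the predictions and the helper names are hypothetical, and labels are treated as independent of the training sets.</p>

```python
import math


def bregman(F, grad_F, a, b):
    """Generic scalar Bregman divergence D_F(a, b) = F(a) - F(b) - F'(b) * (a - b)."""
    return F(a) - F(b) - grad_F(b) * (a - b)


def mean(values):
    return sum(values) / len(values)


def decomposition(F, grad_F, grad_F_inverse, labels, predictions):
    """Return (total, bias, variance, noise) for equally likely labels and predictions."""
    y_star = mean(labels)  # optimal prediction (Lemma 2)
    # Mean prediction from grad F(y_m) = E_D[grad F(y)] (Lemma 1).
    y_m = grad_F_inverse(mean([grad_F(y) for y in predictions]))
    total = mean([mean([bregman(F, grad_F, t, y) for t in labels]) for y in predictions])
    bias = bregman(F, grad_F, y_star, y_m)
    variance = mean([bregman(F, grad_F, y_m, y) for y in predictions])
    noise = mean([bregman(F, grad_F, t, y_star) for t in labels])
    return total, bias, variance, noise


# Hypothetical data: labels t from Pr(t | x), predictions y from models trained on different D.
labels = [0.5, 1.0, 2.0, 4.5]
predictions = [0.8, 1.5, 2.4, 3.0]

# Two assumed generators, each given as (F, grad F, inverse of grad F).
generators = {
    "square": (lambda y: y * y, lambda y: 2.0 * y, lambda g: g / 2.0),
    "entropy-like": (lambda y: y * math.log(y) - y, math.log, math.exp),
}

results = {}
for name, (F, grad_F, grad_F_inverse) in generators.items():
    results[name] = decomposition(F, grad_F, grad_F_inverse, labels, predictions)
    total, bias, variance, noise = results[name]
    print(name, total, bias + variance + noise)
```

<p>For the square generator the mean prediction reduces to the arithmetic mean, whereas for the second generator it is the geometric mean; in both cases the two printed numbers should coincide.</p>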
<section id="square-loss-1" class="level4" data-number="5.2.1">
<h4 data-number="5.2.1" class="anchored" data-anchor-id="square-loss-1"><span class="header-section-number">5.2.1</span> Square loss</h4>
<p>As MSE or square loss is a special instance of Bregman divergence, one can apply Theorem&nbsp;4 to obtain the result for MSE as shown in Theorem&nbsp;1.</p>
</section>
</section>
</section>
<section id="kullback-leibler-divergence" class="level2" data-number="6">
<h2 data-number="6" class="anchored" data-anchor-id="kullback-leibler-divergence"><span class="header-section-number">6</span> Kullback-Leibler divergence</h2>
<p>KL divergence is a special case of Bregman divergence. However, the analysis of the Bregman divergence presented in this post assumes an <em>unbounded</em> support, whereas the support of the KL divergence is the space of probability distributions. In addition, KL divergence measures the difference between two distributions. These differences lead to a different form of bias-variance decomposition.</p>
<p>In this section, <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7B*%7D,%20%5Cmathbf%7By%7D"> and <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bm%7D"> are label distributions or probabilities. They will be replaced by <img src="https://latex.codecogs.com/png.latex?%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D),%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)"> and <img src="https://latex.codecogs.com/png.latex?%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)">, respectively, to make the formulation easier to understand.</p>
<div id="lem-kl-div-model-prediction" class="theorem lemma">
<p><span class="theorem-title"><strong>Lemma 5</strong></span> (Main model prediction - Eq. (2.3) in <span class="citation" data-cites="heskes1998bias">(Heskes 1998)</span>) The main model prediction when the loss is the KL divergence has the following property: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20=%20%5Coperatorname*%7Bargmin%7D_%7Bq(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cmathrm%7BKL%7D%20%5Bq(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%7C%7C%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%5D%5D%20%5CRightarrow%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20=%20%5Cfrac%7B1%7D%7BZ%7D%20%5Cexp%20%5Cleft%5B%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%5Cln%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%5D%20%5Cright%5D,%0A"> where <img src="https://latex.codecogs.com/png.latex?Z"> is a normalization constant independent of model prediction <img src="https://latex.codecogs.com/png.latex?%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)">.</p>
</div>
<details>
<summary>
Detailed proof
</summary>
<div class="proof">
<p><span class="proof-title"><em>Proof</em>. </span>The proof is similar to that of Lemma&nbsp;1, except that the constraint <img src="https://latex.codecogs.com/png.latex?%5Csum_%7B%5Cmathbf%7Bt%7D%7D%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20=%201"> is taken into account. More specifically, the Lagrangian can be written as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathsf%7BL%7D%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cmathrm%7BKL%7D%20%5B%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%7C%7C%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%5D%5D%20+%20%5Clambda%20(%5Cpmb%7B1%7D%5E%7B%5Ctop%7D%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20-%201),%0A"> where <img src="https://latex.codecogs.com/png.latex?%5Clambda"> is the Lagrange multiplier.</p>
<p>At the optimal point, the gradient of the Lagrangian is zero: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20%5Cnabla_%7B%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5Cmathsf%7BL%7D%20&amp;%20=%20%5Cln%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20-%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cln%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%5D%20+%20%5Clambda%20=%200%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5CRightarrow%20%5Cln%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cln%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%5D%20-%20%5Clambda%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5CRightarrow%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20=%20%5Cunderbrace%7B%5Cfrac%7B1%7D%7B%5Cexp(%5Clambda)%7D%7D_%7B%5Cfrac%7B1%7D%7BZ%7D%7D%20%5Cexp%5B%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cln%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%5D%5D.%0A%20%20%20%20%5Cend%7Baligned%7D%0A"> In fact, the logarithm of the normalization constant, <img src="https://latex.codecogs.com/png.latex?%5Cln%20Z">, is the negative variance: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cln%20Z%20%5Ctimes%20%5Cpmb%7B1%7D%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cln%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%5D%20-%20%5Cln%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cleft%5B%20%5Cln%20%5Cfrac%7B%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%7B%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5Cright%5D.%0A"> Note that: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cln%20Z%20=%20%5Cmathbb%7BE%7D_%7B%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%20%5Cln%20Z%20%5Ctimes%20%5Cpmb%7B1%7D%5D.%0A"> Thus: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%20%20%20%20%5Cln%20Z%20=%20%5Cmathbb%7BE%7D_%7B%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cleft%5B%20%5Cln%20%5Cfrac%7B%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%7B%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5Cright%5D%20=%20-%20%5Ctextcolor%7BMidnightBlue%7D%7B%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cleft%5B%20%5Cmathrm%7BKL%7D%20%5B%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%7C%7C%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%5D%20%5Cright%5D%7D.%0A"></p>
</div>
</details>
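<p>Lemma&nbsp;5 says that the main prediction is a normalised geometric mean of the model predictions. A minimal sketch, assuming three made-up predictive distributions over three classes (the numbers and variable names are purely illustrative), computes this quantity and compares <img src="https://latex.codecogs.com/png.latex?%5Cln%20Z"> with the negative variance term.</p>

```python
import math


def kl(p, q):
    """KL[p || q] between two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


# Hypothetical predictive distributions over 3 classes, one per training set D.
predictions = [
    [0.7, 0.2, 0.1],
    [0.5, 0.3, 0.2],
    [0.6, 0.1, 0.3],
]
num_classes = len(predictions[0])

# Unnormalised main prediction: exp(E_D[ln p_hat(t | x)]) for each class t.
unnormalised = [
    math.exp(sum(math.log(p[c]) for p in predictions) / len(predictions))
    for c in range(num_classes)
]
Z = sum(unnormalised)
pr_m = [u / Z for u in unnormalised]

# Variance term: E_D[KL[Pr_m || p_hat]].
variance = sum(kl(pr_m, p) for p in predictions) / len(predictions)

print(math.log(Z), -variance)
```

<p>Since the variance term is non-negative, the normalization constant computed this way cannot exceed one.</p>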
<div id="thm-kl-decomposition" class="theorem">
<p><span class="theorem-title"><strong>Theorem 5</strong></span> (Decomposition for KL divergence) The bias-variance decomposition for KL divergence can be presented as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cmathrm%7BKL%7D%20%5B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%7C%7C%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%5D%20%5D%20=%20%5Ctextcolor%7BCrimson%7D%7B%5Cmathrm%7BKL%7D%20%5B%20%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%7C%7C%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%5D%7D%20+%20%5Ctextcolor%7BMidnightBlue%7D%7B%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cmathrm%7BKL%7D%20%5B%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%7C%7C%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%5D%20%5D%7D.%0A"></p>
</div>
<div class="proof">
<p><span class="proof-title"><em>Proof</em>. </span>The proof is quite straight-forward from Lemma&nbsp;5.</p>
</div>
<p>The result in Theorem&nbsp;5 does not contain an intrinsic noise term, since the loss defined by the KL divergence is based on the true label distribution instead of each sample <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bt%7D">. To obtain the well-known form of the bias-variance decomposition based on the label <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bt%7D">, the negative log-likelihood <img src="https://latex.codecogs.com/png.latex?-%5Cln%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)"> is used as the loss function. Note that <img src="https://latex.codecogs.com/png.latex?%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)"> is still defined with the KL divergence as the loss function.</p>
<p>From Lemma&nbsp;5, one can obtain: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20-%5Cln%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%5D%20=%20-%5Cln%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20+%20%5Ctextcolor%7BMidnightBlue%7D%7B%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cmathrm%7BKL%7D%20%5B%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%7C%7C%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%5D%20%5D%7D.%0A"></p>
<p>Thus, the negative log-likelihood can be written as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%20-%5Cln%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%5D%20=%20-%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%20%5Cln%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%5D%20+%20%5Ctextcolor%7BMidnightBlue%7D%7B%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cmathrm%7BKL%7D%20%5B%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%7C%7C%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%5D%20%5D%7D.%0A"></p>
<p>Or: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%20%5B%20-%5Cln%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%5D%20=%20%5Ctextcolor%7BCrimson%7D%7B%5Cmathrm%7BKL%7D%5B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%7C%7C%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%5D%7D%20+%20%5Ctextcolor%7BMidnightBlue%7D%7B%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%7D%20%5B%20%5Cmathrm%7BKL%7D%20%5B%20%5Coperatorname%7BPr%7D_%7Bm%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%7C%7C%20%5Chat%7Bp%7D(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%20%5D%20%5D%7D%20+%20%5Ctextcolor%7BGreen%7D%7B%5Cmathbb%7BE%7D_%7B%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%7D%5B-%5Cln%20%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)%5D%7D.%0A"></p>
<p>The bias-variance decomposition for the negative log-likelihood in this case includes an intrinsic noise term, which equals the Shannon entropy of the true label distribution <img src="https://latex.codecogs.com/png.latex?%5CPr(%5Cmathbf%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D)">.</p>
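<p>Both decompositions in this section can be verified on a toy example. The sketch below assumes a hypothetical true label distribution and three hypothetical model predictions over three classes; it checks the two-term decomposition of the expected KL divergence (Theorem&nbsp;5) and the three-term decomposition of the expected negative log-likelihood.</p>

```python
import math


def kl(p, q):
    """KL[p || q] between two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


def mean(values):
    return sum(values) / len(values)


# Hypothetical true label distribution Pr(t | x) over 3 classes.
true_dist = [0.6, 0.3, 0.1]
# Hypothetical model predictions p_hat(t | x), one per training set D.
predictions = [
    [0.7, 0.2, 0.1],
    [0.5, 0.3, 0.2],
    [0.6, 0.1, 0.3],
]

# Main prediction Pr_m: normalised geometric mean of the predictions (Lemma 5).
unnormalised = [
    math.exp(mean([math.log(p[c]) for p in predictions]))
    for c in range(len(true_dist))
]
pr_m = [u / sum(unnormalised) for u in unnormalised]

bias = kl(true_dist, pr_m)
variance = mean([kl(pr_m, p) for p in predictions])
noise = -sum(pi * math.log(pi) for pi in true_dist)  # Shannon entropy of Pr(t | x)

# Expected KL loss (Theorem 5) and expected negative log-likelihood.
expected_kl = mean([kl(true_dist, p) for p in predictions])
expected_nll = mean(
    [-sum(pi * math.log(qi) for pi, qi in zip(true_dist, p)) for p in predictions]
)

print(expected_kl, bias + variance)
print(expected_nll, bias + variance + noise)
```

<p>The only difference between the two printed lines is the entropy of the true label distribution, which does not depend on the model.</p>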
</section>
<section id="conclusion" class="level2" data-number="7">
<h2 data-number="7" class="anchored" data-anchor-id="conclusion"><span class="header-section-number">7</span> Conclusion</h2>
<p>In general, the bias-variance decomposition does not always take the form of bias, variance and noise as commonly seen for MSE. Here, we show that different loss functions may have different decompositions. Nevertheless, the two most common loss functions, i.e., MSE and KL divergence, share a similar form. Note that one needs to be careful when applying such bias-variance decompositions, because the definitions of the <em>main model prediction</em> and the <em>optimal label</em> differ between loss functions.</p>
</section>
<section id="references" class="level2" data-number="8">
<h2 data-number="8" class="anchored" data-anchor-id="references"><span class="header-section-number">8</span> References</h2>
<div id="refs" class="references csl-bib-body hanging-indent">
<div id="ref-domingos2000unified" class="csl-entry">
Domingos, Pedro. 2000. <span>“A Unified Bias-Variance Decomposition.”</span> <em>International Conference on Machine Learning</em>, 231–38.
</div>
<div id="ref-heskes1998bias" class="csl-entry">
Heskes, Tom. 1998. <span>“Bias/Variance Decompositions for Likelihood-Based Estimators.”</span> <em>Neural Computation</em> 10 (6): 1425–33.
</div>
<div id="ref-pfau2025generalized" class="csl-entry">
Pfau, David. 2025. <span>“A Generalized Bias-Variance Decomposition for Bregman Divergences.”</span> <em>arXiv Preprint arXiv:2511.08789</em>.
</div>
</div>


<!-- -->

</section>

<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-reuse"><h2 class="anchored quarto-appendix-heading">Reuse</h2><div class="quarto-appendix-contents"><div><a rel="license" href="https://creativecommons.org/licenses/by/4.0/">CC BY 4.0</a></div></div></section><section class="quarto-appendix-contents" id="quarto-citation"><h2 class="anchored quarto-appendix-heading">Citation</h2><div><div class="quarto-appendix-secondary-label">BibTeX citation:</div><pre class="sourceCode code-with-copy quarto-appendix-bibtex"><code class="sourceCode bibtex">@online{nguyen2022,
  author = {Nguyen, Cuong},
  title = {Bias - Variance Decomposition},
  date = {2022-05-03},
  url = {https://cnguyen10.github.io/posts/bias-variance-decomposition/},
  langid = {en}
}
</code></pre><div class="quarto-appendix-secondary-label">For attribution, please cite this work as:</div><div id="ref-nguyen2022" class="csl-entry quarto-appendix-citeas">
Nguyen, Cuong. 2022. <span>“Bias - Variance Decomposition.”</span> May
3. <a href="https://cnguyen10.github.io/posts/bias-variance-decomposition/">https://cnguyen10.github.io/posts/bias-variance-decomposition/</a>.
</div></div></section></div> ]]></description>
  <category>Machine Learning</category>
  <category>Statistics</category>
  <guid>https://cnguyen10.github.io/posts/bias-variance-decomposition/</guid>
  <pubDate>Tue, 03 May 2022 00:00:00 GMT</pubDate>
</item>
<item>
  <title>From hyper-parameter optimisation to meta-learning</title>
  <dc:creator>Cuong Nguyen</dc:creator>
  <link>https://cnguyen10.github.io/posts/meta-learning/</link>
  <description><![CDATA[ 




<p>Meta-learning, also known as <em>learning-to-learn</em>, has been studied since the 1980s <span class="citation" data-cites="schmidhuber1987evolutionary naik1992meta">(Schmidhuber 1987; Naik and Mammone 1992)</span>, and has recently attracted much attention from the research community. Meta-learning is a technique in <em>transfer learning</em>, a learning paradigm that utilises knowledge gained from past experience to facilitate learning in the future. Because it is often defined only <q>implicitly</q>, meta-learning is frequently confused with other transfer learning techniques, e.g.&nbsp;<em>fine-tuning</em>, <em>multi-task learning</em>, <em>domain adaptation</em> and <em>continual learning</em>. The purpose of this post is to formulate meta-learning explicitly via <em>empirical Bayes</em>, and in particular <em>hyper-parameter optimisation</em>, to differentiate meta-learning from those common transfer learning approaches.</p>
<p>This post is structured as follows. First, we define some terminology used in transfer learning in general and review hyper-parameter optimisation in the single-task setting. We then formulate meta-learning as an extension of hyper-parameter optimisation to the multi-task setting. Finally, we show the differences between meta-learning and other transfer-learning approaches.</p>
<section id="background" class="level2" data-number="1">
<h2 data-number="1" class="anchored" data-anchor-id="background"><span class="header-section-number">1</span> Background</h2>
<section id="data-generation-model-of-a-task" class="level3" data-number="1.1">
<h3 data-number="1.1" class="anchored" data-anchor-id="data-generation-model-of-a-task"><span class="header-section-number">1.1</span> Data generation model of a task</h3>
<p>A data point of a task indexed by <img src="https://latex.codecogs.com/png.latex?i%20%5Cin%20%5Cmathbb%7BN%7D"> consists of an input <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D_%7Bij%7D%20%5Cin%20%5Cmathcal%7BX%7D%20%5Csubseteq%20%5Cmathbb%7BR%7D%5E%7Bd%7D"> and a corresponding label <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bij%7D%20%5Cin%20%5Cmathcal%7BY%7D"> with <img src="https://latex.codecogs.com/png.latex?j%20%5Cin%20%5Cmathbb%7BN%7D">. For simplicity, only two families of tasks – regression and classification – are considered in this post. As a result, the label space is defined as <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BY%7D%20%5Csubseteq%20%5Cmathbb%7BR%7D"> for regression and as <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BY%7D%20=%20%5C%7B0,%201,%20%5Cldots,%20C%20-%201%5C%7D"> for classification, where <img src="https://latex.codecogs.com/png.latex?C"> is the number of classes.</p>
<p>Each data point in a task can be generated in 2 steps:</p>
<ul>
<li>generate the input <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D_%7Bij%7D"> by sampling from some probability distribution <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D_%7Bi%7D">,</li>
<li>determine the label <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bij%7D%20=%20f_%7Bi%7D(%5Cmathbf%7Bx%7D_%7Bij%7D)">, where <img src="https://latex.codecogs.com/png.latex?f_%7Bi%7D:%20%5Cmathcal%7BX%7D%20%5Cto%20%5Cmathcal%7BY%7D"> is the <q>correct</q> labelling function.</li>
</ul>
<p>Both the probability distribution <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D_%7Bi%7D"> and the labelling function <img src="https://latex.codecogs.com/png.latex?f_%7Bi%7D"> are unknown to the learning agent during training, and the aim of supervised learning is to use the generated data to infer the labelling function <img src="https://latex.codecogs.com/png.latex?f_%7Bi%7D">.</p>
<p>For simplicity, we denote <img src="https://latex.codecogs.com/png.latex?(%5Cmathbf%7Bx%7D_%7Bij%7D,%20%5Cmathbf%7By%7D_%7Bij%7D)%20%5Csim%20(%5Cmathcal%7BD%7D_%7Bi%7D,%20f_%7Bi%7D)"> as the data generation model of the <img src="https://latex.codecogs.com/png.latex?i">-th task.</p>
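<p>The two-step generation process can be sketched in a few lines of Python. Everything task-specific here is an assumption made for illustration: the input distribution is taken to be a zero-mean Gaussian, the labelling function a scaled and shifted sinusoid, and <code>make_task</code> is a hypothetical helper name.</p>

```python
import math
import random


def make_task(amplitude, phase, input_scale):
    """Return a sampler for a hypothetical regression task (D_i, f_i)."""

    def labelling_function(x):
        # f_i: the "correct" labelling function of this task (unknown to the learner).
        return amplitude * math.sin(x + phase)

    def sample(rng, num_points):
        data = []
        for _ in range(num_points):
            x = rng.gauss(0.0, input_scale)  # step 1: x_ij drawn from D_i
            y = labelling_function(x)        # step 2: y_ij = f_i(x_ij)
            data.append((x, y))
        return data

    return sample


rng = random.Random(0)
task_sampler = make_task(amplitude=2.0, phase=0.5, input_scale=1.0)
dataset = task_sampler(rng, num_points=5)
print(dataset)
```

<p>Different tasks would correspond to different choices of the assumed parameters, i.e. different pairs of input distribution and labelling function.</p>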
</section>
<section id="task-instance" class="level3" data-number="1.2">
<h3 data-number="1.2" class="anchored" data-anchor-id="task-instance"><span class="header-section-number">1.2</span> Task instance</h3>
<div id="def-task-instance" class="theorem definition">
<p><span class="theorem-title"><strong>Definition 1</strong></span> <span class="citation" data-cites="hospedales2021meta">(Hospedales et al. 2021)</span></p>
<p>A <em>task</em> or a <em>task instance</em> <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BT%7D_%7Bi%7D"> consists of an unknown associated data generation model <img src="https://latex.codecogs.com/png.latex?(%5Cmathcal%7BD%7D_%7Bi%7D,%20f_%7Bi%7D)">, and a loss function <img src="https://latex.codecogs.com/png.latex?%5Cell_%7Bi%7D">, denoted as: <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathcal%7BT%7D_%7Bi%7D%20=%20%5C%7B(%5Cmathcal%7BD%7D_%7Bi%7D,%20f_%7Bi%7D),%20%5Cell_%7Bi%7D%5C%7D.%0A"></p>
</div>
<div class="proof remark">
<p><span class="proof-title"><em>Remark</em>. </span>The loss function <img src="https://latex.codecogs.com/png.latex?%5Cell_%7Bi%7D"> is defined abstractly, and can be either:</p>
<ul>
<li><p>negative log-likelihood (NLL): <img src="https://latex.codecogs.com/png.latex?-%20%5Cln%20p(y_%7Bij%7D%20%7C%20%5Cmathbf%7Bx%7D_%7Bij%7D,%20%5Cmathbf%7Bw%7D_%7Bi%7D)">, corresponding to maximum likelihood estimation. This type of loss is quite common in practice, for example:</p>
<ul>
<li>mean squared error (MSE) in regression</li>
<li>cross-entropy in classification</li>
</ul></li>
<li><p>variational-free energy (negative <em>evidence lower-bound</em>) — corresponding to the objective function in variational inference.</p></li>
</ul>
</div>
<p>To solve a task <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BT%7D_%7Bi%7D">, one needs to obtain an optimal task-specific model <img src="https://latex.codecogs.com/png.latex?%7Bh(.;%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D):%20%5Cmathcal%7BX%7D%20%5Cto%20%5Cmathcal%7BY%7D%7D">, parameterised by <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bw%7D%5E%7B*%7D_%7Bi%7D%20%5Cin%20%5Cmathcal%7BW%7D%20%5Csubseteq%20%5Cmathbb%7BR%7D%5E%7Bn%7D">, which minimises a loss function <img src="https://latex.codecogs.com/png.latex?%5Cell_%7Bi%7D"> on the data of that task: <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D%20=%20%5Carg%5Cmin_%7B%5Cmathbf%7Bw%7D_%7Bi%7D%7D%20%5Cmathbb%7BE%7D_%7B(%5Cmathbf%7Bx%7D_%7Bij%7D,%20%5Cmathbf%7By%7D_%7Bij%7D)%20%5Csim%20(%5Cmathcal%7BD%7D_%7Bi%7D,%20f_%7Bi%7D)%7D%20%5Cleft%5B%20%5Cell_%7Bi%7D%20(%5Cmathbf%7Bx%7D_%7Bij%7D,%20%5Cmathbf%7By%7D_%7Bij%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D)%20%5Cright%5D.%0A"></p>
<p>In practice, since both <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D_%7Bi%7D"> and <img src="https://latex.codecogs.com/png.latex?f_%7Bi%7D"> are unknown, the data generation model is replaced by a dataset consisting of a finite number of data-points generated according to the data generation model <img src="https://latex.codecogs.com/png.latex?(%5Cmathcal%7BD%7D_%7Bi%7D,%20f_%7Bi%7D)">, denoted as <img src="https://latex.codecogs.com/png.latex?S_%7Bi%7D%20=%20%5C%7B%5Cmathbf%7Bx%7D_%7Bij%7D,%20%5Cmathbf%7By%7D_%7Bij%7D%5C%7D_%7Bj=1%7D%5E%7Bm_%7Bi%7D%7D">. The objective to solve that task is often known as empirical risk minimisation: <span id="eq-objective_minimise_loss"><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathbf%7Bw%7D%5E%7B%5Cmathrm%7BERM%7D%7D_%7Bi%7D%20=%20%5Carg%5Cmin_%7B%5Cmathbf%7Bw%7D_%7Bi%7D%7D%20%5Cfrac%7B1%7D%7Bm_%7Bi%7D%7D%20%5Csum_%7Bj%20=%201%7D%5E%7Bm_%7Bi%7D%7D%20%5Cleft%5B%20%5Cell_%7Bi%7D%20(%5Cmathbf%7Bx%7D_%7Bij%7D,%20%5Cmathbf%7By%7D_%7Bij%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D)%20%5Cright%5D.%0A%5Ctag%7B1%7D"></span></p>
<p>Since the loss function is the same for every task in a task family, e.g.&nbsp;<img src="https://latex.codecogs.com/png.latex?%5Cell"> is NLL or variational-free energy, the subscript on the loss function is dropped, and the loss is denoted as <img src="https://latex.codecogs.com/png.latex?%5Cell"> throughout this post. Furthermore, since the loss function is common to all tasks, a task can simply be represented by either its data generation model <img src="https://latex.codecogs.com/png.latex?(%5Cmathcal%7BD%7D_%7Bi%7D,%20f_%7Bi%7D)"> or the associated dataset <img src="https://latex.codecogs.com/png.latex?S_%7Bi%7D">.</p>
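<p>To make empirical risk minimisation in Equation&nbsp;1 concrete, here is a minimal sketch for a one-parameter model with a squared-error loss, solved by plain gradient descent. The model <code>h(x; w) = w * x</code>, the toy dataset and the step size are assumptions chosen for illustration.</p>

```python
def empirical_risk(w, dataset):
    """Average squared-error loss of the model h(x; w) = w * x on a dataset."""
    return sum((w * x - y) ** 2 for x, y in dataset) / len(dataset)


def empirical_risk_grad(w, dataset):
    """Derivative of the empirical risk with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in dataset) / len(dataset)


def erm_gradient_descent(dataset, w0=0.0, lr=0.1, steps=500):
    """Approximate the empirical risk minimiser by gradient descent."""
    w = w0
    for _ in range(steps):
        w -= lr * empirical_risk_grad(w, dataset)
    return w


# Toy dataset S_i whose labels follow y = 3x (hypothetical values).
S_i = [(-1.0, -3.0), (0.5, 1.5), (1.0, 3.0), (2.0, 6.0)]
w_erm = erm_gradient_descent(S_i)
```

<p>Because the labels are noise-free, the recovered <code>w_erm</code> approaches the slope of the labelling function and the empirical risk approaches zero.</p>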
</section>
<section id="hyper-parameter-optimisation" class="level3" data-number="1.3">
<h3 data-number="1.3" class="anchored" data-anchor-id="hyper-parameter-optimisation"><span class="header-section-number">1.3</span> Hyper-parameter optimisation</h3>
<p>In the single-task setting, the common way to <q>tune</q> or optimise a hyper-parameter is to split a given dataset <img src="https://latex.codecogs.com/png.latex?S_%7Bi%7D"> into two disjoint subsets: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0AS_%7Bi%7D%5E%7B(t)%7D%20%5Ccup%20S_%7Bi%7D%5E%7B(v)%7D%20&amp;%20=%20S_%7Bi%7D%5C%5C%0AS_%7Bi%7D%5E%7B(t)%7D%20%5Ccap%20S_%7Bi%7D%5E%7B(v)%7D%20&amp;%20=%20%5Cvarnothing,%0A%5Cend%7Baligned%7D%0A"> where:</p>
<ul>
<li><img src="https://latex.codecogs.com/png.latex?S_%7Bi%7D%5E%7B(t)%7D%20=%20%5Cleft%5C%7B%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bij%7D%5E%7B(t)%7D,%20y_%7Bij%7D%5E%7B(t)%7D%20%5Cright)%20%5Cright%5C%7D_%7Bj=1%7D%5E%7Bm_%7Bi%7D%5E%7B(t)%7D%7D"> is the <em>training</em> (or <em>support</em>) subset,</li>
<li><img src="https://latex.codecogs.com/png.latex?S_%7Bi%7D%5E%7B(v)%7D%20=%20%5Cleft%5C%7B%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bij%7D%5E%7B(v)%7D,%20y_%7Bij%7D%5E%7B(v)%7D%20%5Cright)%20%5Cright%5C%7D_%7Bj=1%7D%5E%7Bm_%7Bi%7D%5E%7B(v)%7D%7D"> is the <em>validation</em> (or <em>query</em>) subset.</li>
</ul>
<p>Note that with this definition, <img src="https://latex.codecogs.com/png.latex?m_%7Bi%7D%5E%7B(t)%7D%20+%20m_%7Bi%7D%5E%7B(v)%7D%20=%20m_%7Bi%7D">, and <img src="https://latex.codecogs.com/png.latex?m_%7Bi%7D%5E%7B(t)%7D"> and <img src="https://latex.codecogs.com/png.latex?m_%7Bi%7D%5E%7B(v)%7D"> are not necessarily identical.</p>
<p>The subset <img src="https://latex.codecogs.com/png.latex?S_%7Bi%7D%5E%7B(t)%7D"> is used to train the model parameter of interest <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bw%7D_%7Bi%7D">, while the subset <img src="https://latex.codecogs.com/png.latex?S_%7Bi%7D%5E%7B(v)%7D"> is used to validate the hyper-parameter, denoted by <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> (we provide examples of the hyper-parameter in Section Formulation of meta-learning). Mathematically, hyper-parameter optimisation in the single-task setting can be written as the following bi-level optimisation: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A&amp;%20%5Cmin_%7B%5Ctheta%7D%20%5Cfrac%7B1%7D%7Bm_%7Bi%7D%5E%7B(v)%7D%7D%20%5Csum_%7Bk%20=%201%7D%5E%7Bm_%7Bi%7D%5E%7B(v)%7D%7D%20%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D%20(%5Ctheta)%20%5Cright)%5C%5C%0A&amp;%20%5Ctext%7Bs.t.:%20%7D%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D%20(%5Ctheta)%20=%20%5Carg%5Cmin_%7B%5Cmathbf%7Bw%7D_%7Bi%7D%7D%20%5Cfrac%7B1%7D%7Bm_%7Bi%7D%5E%7B(t)%7D%7D%20%5Csum_%7Bj%20=%201%7D%5E%7Bm_%7Bi%7D%5E%7B(t)%7D%7D%20%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bij%7D%5E%7B(t)%7D,%20y_%7Bij%7D%5E%7B(t)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D%20(%5Ctheta)%20%5Cright).%0A%5Cend%7Baligned%7D%0A"></p>
<p>We can extend the hyper-parameter optimisation from the two data subsets <img src="https://latex.codecogs.com/png.latex?S_%7Bi%7D%5E%7B(t)%7D"> and <img src="https://latex.codecogs.com/png.latex?S_%7Bi%7D%5E%7B(v)%7D"> to the general data generation model as follows: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A&amp;%20%5Cmin_%7B%5Ctheta%7D%20%5Cmathbb%7BE%7D_%7B%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D%20%5Cright)%20%5Csim%20%5Cleft(%20%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(v)%7D,%20f_%7Bi%7D%20%5Cright)%7D%20%5Cleft%5B%20%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D%20(%5Ctheta)%20%5Cright)%20%5Cright%5D%5C%5C%0A&amp;%20%5Ctext%7Bs.t.:%20%7D%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D%20(%5Ctheta)%20=%20%5Carg%5Cmin_%7B%5Cmathbf%7Bw%7D_%7Bi%7D%7D%20%5Cmathbb%7BE%7D_%7B%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bij%7D%5E%7B(t)%7D,%20y_%7Bij%7D%5E%7B(t)%7D%20%5Cright)%20%5Csim%20%5Cleft(%20%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(t)%7D,%20f_%7Bi%7D%20%5Cright)%7D%20%5Cleft%5B%20%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bij%7D%5E%7B(t)%7D,%20y_%7Bij%7D%5E%7B(t)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D%20(%5Ctheta)%20%5Cright)%20%5Cright%5D,%0A%5Cend%7Baligned%7D%0A"> where <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(t)%7D"> and <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(v)%7D"> are the probability distributions of training and validation input data, respectively, and they are not necessarily identical.</p>
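<p>The bi-level structure can be sketched with a small example in which the hyper-parameter is a regularisation strength. This is an assumption made for illustration: the lower level fits a one-parameter model on the training subset with an added penalty <code>lam * w**2</code>, and the upper level picks, by grid search, the value of <code>lam</code> with the lowest validation loss. The function names and data values are hypothetical.</p>

```python
def fit_ridge(train, lam):
    """Lower level: closed-form minimiser of mean squared error + lam * w**2."""
    sxy = sum(x * y for x, y in train)
    sxx = sum(x * x for x, _ in train)
    return sxy / (sxx + lam * len(train))


def mse(w, data):
    """Mean squared error of the model h(x; w) = w * x."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)


def tune(train, val, candidates):
    """Upper level: choose the hyper-parameter with the lowest validation loss."""
    return min(candidates, key=lambda lam: mse(fit_ridge(train, lam), val))


S_t = [(1.0, 2.5), (2.0, 5.0)]  # training subset, labels biased upwards
S_v = [(1.5, 3.0), (3.0, 6.0)]  # validation subset, labels follow y = 2x
best_lam = tune(S_t, S_v, [0.0, 0.1, 0.5, 1.0, 2.0])
```

<p>Grid search stands in for the upper-level minimisation here only because the hyper-parameter is a single scalar; gradient-based alternatives are discussed later in the post.</p>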
<p>Formulation of meta-learning</p>
<p>The setting of the meta-learning problem considered in this post follows the <em>task environment</em> <span class="citation" data-cites="baxter2000model">(Baxter 2000)</span> that describes the unknown distribution <img src="https://latex.codecogs.com/png.latex?p(%5Cmathcal%7BD%7D,%20f)"> over a family of tasks. Each task <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BT%7D_%7Bi%7D"> is sampled from this task environment and can be represented as <img src="https://latex.codecogs.com/png.latex?%5Cleft(%20%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(t)%7D,%20%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(v)%7D,%20f_%7Bi%7D%20%5Cright)">, where <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(t)%7D"> and <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(v)%7D"> are the probability distributions of training and validation input data, respectively, and are not necessarily identical. The aim of meta-learning is to use <img src="https://latex.codecogs.com/png.latex?T"> training tasks to train a meta-learning model that can be fine-tuned to perform well on an unseen task sampled from the same task environment.</p>
<p>Such meta-learning methods use meta-parameters to model the common latent structure of the task distribution <img src="https://latex.codecogs.com/png.latex?p(%5Cmathcal%7BD%7D,%20f)">. In this post, we consider meta-learning as an extension of hyper-parameter optimisation in single-task learning, where the hyper-parameter of interest, often called the <em>meta-parameter</em>, is shared across many tasks. Similar to the hyper-parameter optimisation presented in Section 1.3, the objective of meta-learning is also a bi-level optimisation: <span id="eq-meta_learning_bilevel_optimisation"><img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A&amp;%20%5Cmin_%7B%5Ctheta%7D%20%20%5Ctextcolor%7Bcrimson%7D%7B%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BT%7D_%7Bi%7D%20%5Csim%20p%20%5Cleft(%20%5Cmathcal%7BD%7D,%20f%20%5Cright)%7D%7D%20%5Cmathbb%7BE%7D_%7B%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D%20%5Cright)%20%5Csim%20%5Cleft(%20%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(v)%7D,%20f_%7Bi%7D%20%5Cright)%7D%20%5Cleft%5B%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%20%5Cright)%20%5Cright%5D%5C%5C%0A&amp;%20%5Ctext%7Bs.t.:%20%7D%20%5Cmathbf%7Bw%7D%5E%7B*%7D_%7Bi%7D(%5Ctheta)%20=%20%5Carg%5Cmin_%7B%5Cmathbf%7Bw%7D_%7Bi%7D%7D%20%5Cmathbb%7BE%7D_%7B%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bij%7D%5E%7B(t)%7D,%20y_%7Bij%7D%5E%7B(t)%7D%20%5Cright)%20%5Csim%20%5Cleft(%20%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(t)%7D,%20f_%7Bi%7D%20%5Cright)%7D%20%5Cleft%5B%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bij%7D%5E%7B(t)%7D,%20y_%7Bij%7D%5E%7B(t)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D(%5Ctheta)%20%5Cright)%20%5Cright%5D.%0A%5Cend%7Baligned%7D%0A%5Ctag%7B2%7D"></span></p>
<p>The difference between meta-learning and hyper-parameter optimisation is that the meta-parameter (also known as hyper-parameter) <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> is shared across all tasks sampled from the task environment <img src="https://latex.codecogs.com/png.latex?p(%5Cmathcal%7BD%7D,%20f)"> as highlighted in <span style="color: crimson;">red</span> colour in Equation&nbsp;2.</p>
<p>In practice, the meta-parameter (or shared hyper-parameter) <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> can be chosen as one of the following:</p>
<ul>
<li><em>learning rate</em> of gradient-based optimisation used to minimise the lower level objective function in Equation&nbsp;2 to learn <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D%20%5Cleft(%5Ctheta%5Cright)"> <span class="citation" data-cites="li2017meta">(Li et al. 2017)</span>,</li>
<li><em>initialisation</em> of model parameter <span class="citation" data-cites="finn2017model">(Finn et al. 2017)</span>,</li>
<li><em>data representation</em> or <em>feature extractor</em> <span class="citation" data-cites="vinyals2016matching snell2017prototypical">(<span class="nocase">Vinyals et al.</span> 2016; Snell et al. 2017)</span>,</li>
<li><em>optimiser</em> used to optimise the lower-level in Equation&nbsp;2.</li>
</ul>
<p>In this post, the meta-parameter <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> is assumed to be the initialisation of model parameters. The formulation, derivation and analysis in the subsequent sections therefore revolve around this assumption. Note that the analysis can be straightforwardly extended to other types of meta-parameters with slight modifications.</p>
<p>In general, the objective function of meta-learning in Equation&nbsp;2 can be solved by gradient-based optimisation, such as gradient descent. Due to the nature of the bi-level optimisation, the optimisation is often carried out in two steps. The first step is to adapt (or fine-tune) the meta-parameter <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> to the task-specific parameter <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bw%7D_%7Bi%7D(%5Ctheta)">. This corresponds to the optimisation in the lower-level, and can be written as: <span id="eq-task_adaptation_sgd"><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%20=%20%5Ctheta%20-%20%5Calpha%20%5Cmathbb%7BE%7D_%7B%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bij%7D%5E%7B(t)%7D,%20y_%7Bij%7D%5E%7B(t)%7D%20%5Cright)%20%5Csim%20%5Cleft(%20%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(t)%7D,%20f_%7Bi%7D%20%5Cright)%7D%20%5Cleft%5B%20%5Cnabla_%7B%5Ctheta%7D%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bij%7D%5E%7B(t)%7D,%20y_%7Bij%7D%5E%7B(t)%7D;%20%5Ctheta%20%5Cright)%20%5Cright%5D,%0A%5Ctag%7B3%7D"></span> where <img src="https://latex.codecogs.com/png.latex?%5Calpha"> is a hyper-parameter denoting the learning rate for task <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BT%7D_%7Bi%7D">. For simplicity, the adaptation step in Equation&nbsp;3 is carried out with only one gradient descent update.</p>
<p>The second step is to minimise the validation loss induced by the locally-optimal task-specific parameter <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)"> evaluated on the validation subset w.r.t. the meta-parameter <img src="https://latex.codecogs.com/png.latex?%5Ctheta">. This corresponds to the upper-level optimisation, and can be expressed as: <span id="eq-meta_parameter_update_sgd"><img src="https://latex.codecogs.com/png.latex?%0A%5Ctheta%20%5Cgets%20%5Ctheta%20-%20%5Cgamma%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BT%7D_%7Bi%7D%20%5Csim%20p(%5Cmathcal%7BD%7D,%20f)%7D%20%5Cmathbb%7BE%7D_%7B%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20%5Cmathbf%7By%7D_%7Bik%7D%5E%7B(v)%7D%20%5Cright)%20%5Csim%20%5Cleft(%20%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(v)%7D,%20f_%7Bi%7D%20%5Cright)%7D%20%5Cleft%5B%20%5Cnabla_%7B%5Ctheta%7D%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20%5Cmathbf%7By%7D_%7Bik%7D%5E%7B(v)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%20%5Cright)%20%5Cright%5D,%0A%5Ctag%7B4%7D"></span> where <img src="https://latex.codecogs.com/png.latex?%5Cgamma"> is another hyper-parameter representing the learning rate to learn <img src="https://latex.codecogs.com/png.latex?%5Ctheta">.</p>
<p>The general algorithm of meta-learning using gradient-based optimisation is shown in Algorithm 1.</p>
<div id="algo-meta-learning" class="pseudocode-container quarto-float" data-caption-prefix="Algorithm" data-line-number-punc=":" data-comment-delimiter="//" data-line-number="true" data-no-end="false" data-pseudocode-number="1" data-indent-size="1.2em">
<div class="pseudocode">
\begin{algorithm} \caption{Training procedure of meta-learning in general} \begin{algorithmic} \Procedure{Training}{task environment $p(\mathcal{D}, f)$, learning rates $\gamma$ and $\alpha$} \State initialise meta-parameter $\theta$ \While{$\theta$ not converged} \State sample a mini-batch of $T$ tasks from task environment $p\left( \mathcal{D}, f \right)$ \For{each task $\mathcal{T}_{i}, i \in \{1, \ldots, T\}$} \State sample two data subsets $S_{i}^{(t)}$ and $S_{i}^{(v)}$ from task $\mathcal{T}_{i} = (\mathcal{D}_{i}^{(t)}, \mathcal{D}_{i}^{(v)}, f_{i})$ \State adapt meta-parameter to task $\mathcal{T}_{i}$: $\mathbf{w}_{i}^{*} \left( \theta \right) = \theta - \frac{\alpha}{m_{i}^{(t)}} \sum_{j = 1}^{m_{i}^{(t)}} \nabla_{\theta} \left[ \ell \left( \mathbf{x}_{ij}^{(t)}, y_{ij}^{(t)}; \theta \right)\right]$ \EndFor \State update meta-parameter: $\theta \gets \theta - \frac{\gamma}{T} \sum_{i=1}^{T} \frac{1}{m_{i}^{(v)}} \sum_{k=1}^{m_{i}^{(v)}} \nabla_{\theta} \left[\ell \left( \mathbf{x}_{ik}^{(v)}, y_{ik}^{(v)}; \mathbf{w}_{i}^{*} \left( \theta \right) \right) \right]$ \EndWhile \State \textbf{return} the trained meta-parameter $\theta$ \EndProcedure \end{algorithmic} \end{algorithm}
</div>
</div>
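<p>Algorithm 1 can be sketched in a few lines for a toy problem. The sketch below assumes a one-parameter model <code>h(x; w) = w * x</code> with a squared-error loss, a hypothetical task environment whose tasks are lines with slopes drawn uniformly from [1, 3], and a fixed number of iterations in place of the convergence test. Because the model is a scalar, the derivative of the adapted parameter w.r.t. the meta-parameter can be written analytically instead of relying on automatic differentiation.</p>

```python
import random


def grad(w, data):
    """Gradient of the mean squared error of h(x; w) = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in data) / len(data)


def hessian(data):
    """Second derivative of the same loss; constant in w for this model."""
    return sum(2.0 * x * x for x, _ in data) / len(data)


def sample_task(rng, m=10):
    """Sample one task and return its training and validation subsets."""
    a = rng.uniform(1.0, 3.0)  # hypothetical task environment: y = a * x

    def draw():
        xs = [rng.uniform(-1.0, 1.0) for _ in range(m)]
        return [(x, a * x) for x in xs]

    return draw(), draw()


def train(alpha=0.1, gamma=0.1, num_tasks=8, iterations=300, seed=0):
    rng = random.Random(seed)
    theta = 0.0  # initialise the meta-parameter
    for _ in range(iterations):
        meta_grad = 0.0
        for _ in range(num_tasks):
            s_train, s_val = sample_task(rng)
            # Adapt the meta-parameter to the task with one gradient step.
            w = theta - alpha * grad(theta, s_train)
            # Derivative of the adapted parameter w.r.t. theta.
            dw_dtheta = 1.0 - alpha * hessian(s_train)
            # Chain rule: gradient of the validation loss w.r.t. theta.
            meta_grad += dw_dtheta * grad(w, s_val)
        theta -= gamma * meta_grad / num_tasks  # update the meta-parameter
    return theta


theta = train()
```

<p>In this toy setting the learnt initialisation settles near the average slope of the task environment, which is the point from which one adaptation step does best on average.</p>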
</section>
<section id="second-order-meta-learning" class="level3" data-number="1.4">
<h3 data-number="1.4" class="anchored" data-anchor-id="second-order-meta-learning"><span class="header-section-number">1.4</span> Second-order meta-learning</h3>
<p>As shown in Equation&nbsp;4, the optimisation for the meta-parameter <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> requires the gradient of the validation loss averaged across <img src="https://latex.codecogs.com/png.latex?T"> tasks. Given that each task-specific parameter <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D"> is a function of <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> due to the lower-level optimisation in Equation&nbsp;3, the gradient of interest can be expanded as: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A&amp;%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BT%7D_%7Bi%7D%20%5Csim%20p%20%5Cleft(%20%5Cmathcal%7BD%7D,%20f%20%5Cright)%7D%20%5Cmathbb%7BE%7D_%7B%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D%20%5Cright)%20%5Csim%20%5Cleft(%20%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(v)%7D,%20f_%7Bi%7D%20%5Cright)%7D%20%5Cleft%5B%20%5Cnabla_%7B%5Ctheta%7D%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%20%5Cright)%20%5Cright%5D%5C%5C%0A&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BT%7D_%7Bi%7D%20%5Csim%20p%20%5Cleft(%20%5Cmathcal%7BD%7D,%20f%20%5Cright)%7D%20%5Cmathbb%7BE%7D_%7B%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D%20%5Cright)%20%5Csim%20%5Cleft(%20%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(v)%7D,%20f_%7Bi%7D%20%5Cright)%7D%20%5Cleft%5B%20%5Cnabla_%7B%5Ctheta%7D%5E%7B%5Ctop%7D%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D%20%5Cleft(%20%5Ctheta%20%5Cright)%20%5Ctimes%20%5Cnabla_%7B%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%7D%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%20%5Cright)%20%5Cright%5D%5C%5C%0A&amp;%20=%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BT%7D_%7Bi%7D%20%5Csim%20p%20%5Cleft(%20%5Cmathcal%7BD%7D,%20f%20%5Cright)%7D%20%5Cleft%5C%7B%20%5Cleft%5B%20%5Cmathbf%7BI%7D%20-%20%5Calpha%20%5Cmathbb%7BE%7D_%7B%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bij%7D%5E%7B(t)%7D,%20y_%7Bij%7D%5E%7B(t)%7D%20%5Cright)%20%5Csim%20%5Cleft(%20%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(t)%7D,%20f_%7Bi%7D%20%5Cright)%7D%20%5Cleft%5B%20%20%5Ctextcolor%7Bcrimson%7D%7B%5Cnabla_%7B%5Ctheta%7D%5E%7B2%7D%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bij%7D%5E%7B(t)%7D,%20y_%7Bij%7D%5E%7B(t)%7D;%20%5Ctheta%20%5Cright)%7D%20%5Cright%5D%20%5Cright%5D%20%5Cright.%5C%5C%0A&amp;%20%5Cquad%20%5Ctimes%20%5Cleft.%20%5Cmathbb%7BE%7D_%7B%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D%20%5Cright)%20%5Csim%20%5Cleft(%20%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(v)%7D,%20f_%7Bi%7D%20%5Cright)%7D%20%5Cleft%5B%20%5Ctextcolor%7Bgreen%7D%7B%5Cnabla_%7B%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%7D%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%20%5Cright)%7D%20%5Cright%5D%20%5Cright%5C%7D,%0A%5Cend%7Baligned%7D%0A"> where the first equality follows from the chain rule, and the second is obtained by differentiating the gradient update in Equation&nbsp;3 with respect to <img src="https://latex.codecogs.com/png.latex?%5Ctheta">. Note that in the second equality, the transpose is dropped since the corresponding matrix is symmetric.</p>
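<p>Thus, naively implementing such a gradient would require calculating the Hessian matrix <img src="https://latex.codecogs.com/png.latex?%5Ctextcolor%7Bcrimson%7D%7B%5Cnabla_%7B%5Ctheta%7D%5E%7B2%7D%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bij%7D%5E%7B(t)%7D,%20y_%7Bij%7D%5E%7B(t)%7D;%20%5Ctheta%20%5Cright)%7D">, resulting in an intractable procedure for large models, such as deep neural networks. To obtain a more efficient implementation, one can utilise the Hessian-vector product <span class="citation" data-cites="pearlmutter94fastexact">(Pearlmutter 1994)</span> between the gradient vector <img src="https://latex.codecogs.com/png.latex?%5Ctextcolor%7Bgreen%7D%7B%5Cnabla_%7B%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%7D%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20%5Cmathbf%7By%7D_%7Bik%7D%5E%7B(v)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%20%5Cright)%7D"> and the Hessian matrix <img src="https://latex.codecogs.com/png.latex?%5Ctextcolor%7Bcrimson%7D%7B%5Cnabla_%7B%5Ctheta%7D%5E%7B2%7D%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bij%7D%5E%7B(t)%7D,%20y_%7Bij%7D%5E%7B(t)%7D;%20%5Ctheta%20%5Cright)%7D"> to efficiently calculate the gradient of the validation loss w.r.t. <img src="https://latex.codecogs.com/png.latex?%5Ctheta">, without ever forming the Hessian matrix explicitly.</p>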
<p>Another way to calculate the gradient of the validation loss w.r.t. the meta-parameter <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> is to use implicit differentiation <span class="citation" data-cites="domke2012generic rajeswaran2019meta lorraine2020optimizing">(Domke 2012; Rajeswaran et al. 2019; Lorraine et al. 2020)</span>. This approach is advantageous because it does not need to store the computational graph of the lower-level optimisation and back-propagate through it via the chain rule. Implicit differentiation therefore reduces memory usage and makes it possible to work with large-scale models. The trade-off, however, is an increase in the computation time needed to calculate the gradient of interest.</p>
<p>In either case, implementations that compute the exact gradient of the validation loss w.r.t. <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> without approximation are often referred to as <q>second-order</q> meta-learning.</p>
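<p>The idea of multiplying by the Hessian without forming it can be sketched as follows. Note that this sketch swaps in a central finite difference of gradients as a simple stand-in for the exact Hessian-vector product of Pearlmutter (1994); the two-parameter linear model, the data values and the function names are likewise assumptions for illustration.</p>

```python
def grad(theta, data):
    """Gradient of the mean squared error of h(x) = theta[0] + theta[1] * x."""
    g0 = g1 = 0.0
    for x, y in data:
        r = theta[0] + theta[1] * x - y
        g0 += 2.0 * r
        g1 += 2.0 * r * x
    n = len(data)
    return [g0 / n, g1 / n]


def hvp(theta, v, data, eps=1e-5):
    """Approximate H v by a central difference of gradients (no explicit H)."""
    plus = grad([t + eps * vi for t, vi in zip(theta, v)], data)
    minus = grad([t - eps * vi for t, vi in zip(theta, v)], data)
    return [(p - m) / (2.0 * eps) for p, m in zip(plus, minus)]


def meta_gradient(theta, s_train, s_val, alpha):
    """(I - alpha * H) g for one task, with g the validation gradient at w*."""
    w = [t - alpha * g for t, g in zip(theta, grad(theta, s_train))]
    g = grad(w, s_val)
    h_g = hvp(theta, g, s_train)
    return [gi - alpha * hgi for gi, hgi in zip(g, h_g)]


s_train = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]  # labels follow y = 1 + 2x
s_val = [(3.0, 7.0), (-1.0, -1.0)]
mg = meta_gradient([0.0, 0.0], s_train, s_val, alpha=0.1)
```

<p>Only gradient evaluations are needed, so the memory cost stays linear in the number of parameters rather than quadratic.</p>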
</section>
<section id="first-order-meta-learning" class="level3" data-number="1.5">
<h3 data-number="1.5" class="anchored" data-anchor-id="first-order-meta-learning"><span class="header-section-number">1.5</span> First-order meta-learning</h3>
<p>In practice, the Hessian matrix <img src="https://latex.codecogs.com/png.latex?%5Ctextcolor%7Bcrimson%7D%7B%5Cnabla_%7B%5Ctheta%7D%5E%7B2%7D%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bij%7D%5E%7B(t)%7D,%20y_%7Bij%7D%5E%7B(t)%7D;%20%5Ctheta%20%5Cright)%7D"> is often omitted from the calculation to simplify the update for the meta-parameter <img src="https://latex.codecogs.com/png.latex?%5Ctheta"> <span class="citation" data-cites="finn2017model">(Finn et al. 2017)</span>. The resulting gradient consists of only the gradient of validation loss <img src="https://latex.codecogs.com/png.latex?%5Ctextcolor%7BGreen%7D%7B%5Cnabla_%7B%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%7D%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%20%5Cright)%7D">, which can be calculated with a single forward and backward pass if automatic differentiation is used. This approximation is often referred to as <q>first-order</q> meta-learning, and the gradient of interest can be presented as: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A&amp;%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BT%7D_%7Bi%7D%20%5Csim%20p%20%5Cleft(%20%5Cmathcal%7BD%7D,%20f%20%5Cright)%7D%20%5Cmathbb%7BE%7D_%7B%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D%20%5Cright)%20%5Csim%20%5Cleft(%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(v)%7D,%20f_%7Bi%7D%20%5Cright)%7D%20%5Cleft%5B%20%5Cnabla_%7B%5Ctheta%7D%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%20%5Cright)%20%5Cright%5D%20%5C%5C%0A&amp;%20%5Capprox%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BT%7D_%7Bi%7D%20%5Csim%20p%20%5Cleft(%20%5Cmathcal%7BD%7D,%20f%20%5Cright)%7D%20%5Cmathbb%7BE%7D_%7B%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D%20%5Cright)%20%5Csim%20%5Cleft(%20%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(v)%7D,%20f_%7Bi%7D%20%5Cright)%7D%20%5Cleft%5B%20%5Ctextcolor%7BGreen%7D%7B%5Cnabla_%7B%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%7D%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%20%5Cright)%7D%20%5Cright%5D.%0A%5Cend%7Baligned%7D%0A"></p>
<p>REPTILE <span class="citation" data-cites="nichol2018on">(Nichol et al. 2018)</span>, a variant of first-order meta-learning, further approximates the gradient of validation loss <img src="https://latex.codecogs.com/png.latex?%5Ctextcolor%7BGreen%7D%7B%5Cnabla_%7B%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%7D%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%20%5Cright)%7D"> by the difference <img src="https://latex.codecogs.com/png.latex?%5Ctheta%20-%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D">, resulting in a much simpler approximation: <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BT%7D_%7Bi%7D%20%5Csim%20p%20%5Cleft(%20%5Cmathcal%7BD%7D,%20f%20%5Cright)%7D%20%5Cmathbb%7BE%7D_%7B%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D%20%5Cright)%20%5Csim%20%5Cleft(%20%5Cmathcal%7BD%7D_%7Bi%7D%5E%7B(v)%7D,%20f_%7Bi%7D%20%5Cright)%7D%20%5Cleft%5B%20%5Cnabla_%7B%5Ctheta%7D%20%5Cell%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bik%7D%5E%7B(v)%7D,%20y_%7Bik%7D%5E%7B(v)%7D;%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%20%5Cright)%20%5Cright%5D%20%5Capprox%20%5Ctheta%20-%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BT%7D_%7Bi%7D%20%5Csim%20p%20%5Cleft(%20%5Cmathcal%7BD%7D,%20f%20%5Cright)%7D%20%5Cleft%5B%20%5Cmathbf%7Bw%7D_%7Bi%7D%5E%7B*%7D(%5Ctheta)%20%5Cright%5D.%0A"></p>
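<p>To compare the three update directions side by side, the sketch below evaluates them for a single task. It assumes a one-parameter model <code>h(x; w) = w * x</code> with a squared-error loss and a single adaptation step, as in Equation&nbsp;3; the data values and the function name <code>update_directions</code> are hypothetical, and the REPTILE entry follows the description given in this post rather than the full procedure of the original paper.</p>

```python
def grad(w, data):
    """Gradient of the mean squared error of h(x; w) = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in data) / len(data)


def hessian(data):
    """Second derivative of the same loss; constant in w for this model."""
    return sum(2.0 * x * x for x, _ in data) / len(data)


def update_directions(theta, s_train, s_val, alpha):
    """Return the three candidate gradients w.r.t. theta for one task."""
    w = theta - alpha * grad(theta, s_train)  # one-step adaptation
    g_val = grad(w, s_val)                    # validation gradient at w*
    return {
        # exact gradient: (1 - alpha * Hessian) times the validation gradient
        "second_order": (1.0 - alpha * hessian(s_train)) * g_val,
        # first-order: the Hessian term is dropped
        "first_order": g_val,
        # REPTILE-style: difference between theta and the adapted parameter
        "reptile": theta - w,
    }


s_train = [(1.0, 2.0), (2.0, 4.0)]  # labels follow y = 2x
s_val = [(3.0, 6.0)]
d = update_directions(theta=0.0, s_train=s_train, s_val=s_val, alpha=0.1)
```

<p>In this example all three quantities have the same sign, so each would move the meta-parameter towards the task's slope, but their magnitudes differ, which in practice is absorbed into the choice of learning rate.</p>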
</section>
</section>
<section id="differentiation-from-other-transfer-learning-approaches" class="level2" data-number="2">
<h2 data-number="2" class="anchored" data-anchor-id="differentiation-from-other-transfer-learning-approaches"><span class="header-section-number">2</span> Differentiation from other transfer learning approaches</h2>
<p>In this section, some popular transfer learning methods are described together with their objective functions, in order to distinguish them from meta-learning.</p>
<section id="fine-tuning" class="level3" data-number="2.1">
<h3 data-number="2.1" class="anchored" data-anchor-id="fine-tuning"><span class="header-section-number">2.1</span> Fine-tuning</h3>
<p>Fine-tuning is the most common technique in neural-network-based transfer learning <span class="citation" data-cites="pratt1991direct yosinski2014transferable">(Pratt et al. 1991; Yosinski et al. 2014)</span>, where the last layer, or the last few layers, of a neural network pre-trained on a source task are replaced and fine-tuned on a target task. Formally, let <img src="https://latex.codecogs.com/png.latex?g(.;%20%5Cmathbf%7Bw%7D_%7B0%7D)"> denote the forward function of the shared layers with shared parameters <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bw%7D_%7B0%7D">, and let <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bw%7D_%7Bs%7D"> and <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bw%7D_%7Bt%7D"> be the parameters of the remaining layers <img src="https://latex.codecogs.com/png.latex?h"> trained specifically on the source and target tasks, respectively. The objective of fine-tuning can then be expressed as: <span id="eq-fine_tuning_formulation"><img 
src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A&amp;%20%5Cmin_%7B%5Cmathbf%7Bw%7D_%7Bt%7D%7D%20%5Cmathbb%7BE%7D_%7B(%5Cmathbf%7Bx%7D_%7Bt%7D,%20%5Cmathbf%7By%7D_%7Bt%7D)%20%5Csim%20%5Cmathcal%7BT%7D_%7Bt%7D%7D%20%5Cleft%5B%20%5Cell%20%5Cleft(%20h%5Cleft(%20g%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bt%7D;%20%5Cmathbf%7Bw%7D_%7B0%7D%5E%7B*%7D%20%5Cright);%20%5Cmathbf%7Bw%7D_%7Bt%7D%20%5Cright),%20%5Cmathbf%7By%7D_%7Bt%7D%20%5Cright)%20%5Cright%5D%20%5C%5C%0A&amp;%20%5Ctext%7Bs.t.:%20%7D%20%5Cmathbf%7Bw%7D_%7B0%7D%5E%7B*%7D,%20%5Cmathbf%7Bw%7D_%7Bs%7D%5E%7B*%7D%20=%20%5Carg%5Cmin_%7B%5Cmathbf%7Bw%7D_%7B0%7D,%20%5Cmathbf%7Bw%7D_%7Bs%7D%7D%20%5Cmathbb%7BE%7D_%7B(%5Cmathbf%7Bx%7D_%7Bs%7D,%20%5Cmathbf%7By%7D_%7Bs%7D)%20%5Csim%20%5Cmathcal%7BT%7D_%7Bs%7D%7D%20%5Cleft%5B%20%5Cell%20%5Cleft(%20h%20%5Cleft(%20g%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bs%7D;%20%5Cmathbf%7Bw%7D_%7B0%7D%20%5Cright);%20%5Cmathbf%7Bw%7D_%7Bs%7D%20%5Cright),%20%5Cmathbf%7By%7D_%7Bs%7D%20%5Cright)%20%5Cright%5D,%0A%5Cend%7Baligned%7D%0A%5Ctag%7B5%7D"></span></p>
<p>where <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D_%7Bs%7D,%20%5Cmathbf%7By%7D_%7Bs%7D"> and <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D_%7Bt%7D,%20%5Cmathbf%7By%7D_%7Bt%7D"> are the data sampled from the source task <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BT%7D_%7Bs%7D"> and target task <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BT%7D_%7Bt%7D">, respectively.</p>
<p>Although the objective of fine-tuning shown in Equation&nbsp;5 is still a bi-level optimisation, it is easier to solve than the one in meta-learning for the following reasons:</p>
<ul>
<li>The objective in fine-tuning has only one constraint corresponding to one source task, while meta-learning has several constraints corresponding to multiple training tasks.</li>
<li>In fine-tuning, <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bw%7D_%7Bt%7D"> and <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bw%7D_%7B0%7D"> are inferred separately, while in meta-learning, the task-specific parameter is a function of the meta-parameter, resulting in a more complicated dependency.</li>
</ul>
<p>The downside of fine-tuning is the requirement of a reasonable number of training examples on the target task to fine-tune <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bw%7D_%7Bt%7D">. In contrast, meta-learning leverages the knowledge extracted from several training tasks to quickly adapt to a new task with only a few training examples.</p>
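<p>The two levels of Equation&nbsp;5 can be sketched with scalar stand-ins for the network layers. Here the shared layer is assumed to be <code>g(x; w0) = w0 * x</code> and the head <code>h(z; w) = w * z</code>, with a squared-error loss; the function names, data values and step size are hypothetical choices for illustration.</p>

```python
def pretrain(source, lr=0.01, steps=1000):
    """Lower level: fit the shared w0 and the source head ws on the source task."""
    w0, ws = 1.0, 1.0
    n = len(source)
    for _ in range(steps):
        g0 = sum(2.0 * (ws * w0 * x - y) * ws * x for x, y in source) / n
        gs = sum(2.0 * (ws * w0 * x - y) * w0 * x for x, y in source) / n
        w0, ws = w0 - lr * g0, ws - lr * gs
    return w0, ws


def fine_tune(target, w0):
    """Upper level: w0 stays frozen; only a new head wt is fitted (closed form)."""
    feats = [(w0 * x, y) for x, y in target]
    return sum(z * y for z, y in feats) / sum(z * z for z, _ in feats)


source = [(1.0, 4.0), (2.0, 8.0)]  # source task: y = 4x
target = [(1.0, -2.0)]             # target task: y = -2x, a single example
w0, ws = pretrain(source)
wt = fine_tune(target, w0)
```

<p>The shared parameter is learnt once on the source task and never revisited, which is the separation between the two levels that the list above points to.</p>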
</section>
<section id="domain-adaptation-and-generalisation" class="level3" data-number="2.2">
<h3 data-number="2.2" class="anchored" data-anchor-id="domain-adaptation-and-generalisation"><span class="header-section-number">2.2</span> Domain adaptation and generalisation</h3>
<p>Domain adaptation, or domain shift, refers to the case where the joint data-label distributions of the source and target domains differ, denoted as <img src="https://latex.codecogs.com/png.latex?p_%7Bs%7D%20%5Cleft(%20%5Cmathcal%7BD%7D,%20f%20%5Cright)%20%5Cneq%20p_%7Bt%7D%20%5Cleft(%20%5Cmathcal%7BD%7D,%20f%20%5Cright)">, or simply <img src="https://latex.codecogs.com/png.latex?p_%7Bs%7D(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7By%7D)%20%5Cneq%20p_%7Bt%7D(%5Cmathbf%7Bx%7D,%20%5Cmathbf%7By%7D)"> <span class="citation" data-cites="heckman1979sample shimodaira2000improving japkowicz2002class daume2006domain ben2007analysis">(Heckman 1979; Shimodaira 2000; Japkowicz and Stephen 2002; Daume III and Marcu 2006; <span class="nocase">Ben-David et al.</span> 2007)</span>. The aim of domain adaptation is to adapt a model trained on the source domain using the data available in the target domain, so that the adapted model performs reasonably well on the target domain. In other words, domain adaptation relies on a data transformation <img src="https://latex.codecogs.com/png.latex?g(.,%20.;%20%5Cmathbf%7Bw%7D_%7B0%7D):%20%5Cmathcal%7BX%7D%20%5Ctimes%20%5Cmathcal%7BY%7D%20%5Cto%20%5Cmathcal%7BX%7D%5E%7B%5Cprime%7D%20%5Ctimes%20%5Cmathcal%7BY%7D%5E%7B%5Cprime%7D"> that produces a domain-invariant latent space.
Mathematically, the transformation <img src="https://latex.codecogs.com/png.latex?g"> is obtained by minimising a divergence between the two transformed data distributions: <span id="eq-domain_adaptation"><img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A&amp;%20%5Cmin_%7B%5Cmathbf%7Bw%7D_%7B0%7D%7D%20%5Cmathrm%7BDivergence%7D%20%5Cleft%5B%20p%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bs%7D%5E%7B%5Cprime%7D,%20%5Cmathbf%7By%7D_%7Bs%7D%5E%7B%5Cprime%7D%20%5Cright)%20%7C%7C%20p%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bt%7D%5E%7B%5Cprime%7D,%20%5Cmathbf%7By%7D_%7Bt%7D%5E%7B%5Cprime%7D%20%5Cright)%20%5Cright%5D%5C%5C%0A&amp;%20%5Ctext%7Bs.t.:%20%7D%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bi%7D%5E%7B%5Cprime%7D,%20%5Cmathbf%7By%7D_%7Bi%7D%5E%7B%5Cprime%7D%20%5Cright)%20=%20g%20%5Cleft(%20%5Cmathbf%7Bx%7D_%7Bi%7D,%20%5Cmathbf%7By%7D_%7Bi%7D;%20%5Cmathbf%7Bw%7D_%7B0%7D%20%5Cright),%20i%20%5Cin%20%5C%7Bs,%20t%5C%7D.%0A%5Cend%7Baligned%7D%0A%5Ctag%7B6%7D"></span></p>
<p>After obtaining the transformation <img src="https://latex.codecogs.com/png.latex?g">, one can simply train a model using the transformed data of the source domain, and then use that model to make predictions on the target domain.</p>
<p>Given the optimisation in Equation&nbsp;6, domain adaptation differs from meta-learning for the following reasons:</p>
<ul>
<li>Domain adaptation assumes a shift between the task environments that generate the source and target tasks, while meta-learning assumes that all tasks are generated from the same task environment.</li>
<li>Domain adaptation uses data from the target domain, while meta-learning has no such access.</li>
</ul>
<p>In general, meta-learning learns a shared prior or hyper-parameters to generalise to unseen tasks, while domain adaptation produces a model to solve a particular task in a specified target domain. Recently, a variant of domain adaptation, named <b>domain generalisation</b>, has emerged, where the aim is to learn a domain-invariant model without any information about the target domain. In this view, domain generalisation is very similar to meta-learning, and some works employ meta-learning algorithms for domain generalisation <span class="citation" data-cites="li2018learning li2019feature">(Li et al. 2018; Li et al. 2019)</span>.</p>
</section>
<section id="multi-task-learning" class="level3" data-number="2.3">
<h3 data-number="2.3" class="anchored" data-anchor-id="multi-task-learning"><span class="header-section-number">2.3</span> Multi-task learning</h3>
<p>Multi-task learning learns several related auxiliary tasks and a target task simultaneously to exploit the diversity of task representation to regularise and improve the performance on the target task <span class="citation" data-cites="caruana1997multitask">(Caruana 1997)</span>. If the input <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D"> is assumed to be the same across <img src="https://latex.codecogs.com/png.latex?T"> extra tasks and the target task <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BT%7D_%7BT%20+%201%7D">, then the objective of multi-task learning can be expressed as: <span id="eq-mtl_formulation"><img src="https://latex.codecogs.com/png.latex?%0A%5Cmin_%7B%5Cmathbf%7Bw%7D_%7B0%7D,%20%5C%7B%5Cmathbf%7Bw%7D_%7Bi%7D%5C%7D_%7Bi%20=%201%7D%5E%7BT%20+%201%7D%7D%20%5Cfrac%7B1%7D%7BT%20+%201%7D%20%5Csum_%7Bi%20=%201%7D%5E%7BT%20+%201%7D%20%5Cell_%7Bi%7D%20%5Cleft(%20h_%7Bi%7D%20%5Cleft(%20g%5Cleft(%20%5Cmathbf%7Bx%7D;%20%5Cmathbf%7Bw%7D_%7B0%7D%20%5Cright);%20%5Cmathbf%7Bw%7D_%7Bi%7D%20%5Cright),%20%5Cmathbf%7By%7D_%7Bi%7D%20%5Cright),%0A%5Ctag%7B7%7D"></span> where <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7By%7D_%7Bi%7D,%20%5Cell_%7Bi%7D"> and <img src="https://latex.codecogs.com/png.latex?h_%7Bi%7D"> are the label, loss function and the classifier for task <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BT%7D_%7Bi%7D">, respectively, and <img src="https://latex.codecogs.com/png.latex?g(.,%20%5Cmathbf%7Bw%7D_%7B0%7D)"> is the shared feature extractor for <img src="https://latex.codecogs.com/png.latex?T%20+%201"> tasks.</p>
<p>Multi-task learning is often confused with meta-learning because both extract information from many tasks. However, the objective function of multi-task learning in Equation&nbsp;7 is a single-level optimisation over the shared parameter <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bw%7D_%7B0%7D"> and the multiple task-specific classifiers <img src="https://latex.codecogs.com/png.latex?%5C%7B%5Cmathbf%7Bw%7D_%7Bi%7D%5C%7D_%7Bi%20=%201%7D%5E%7BT%20+%201%7D">. It is, therefore, not as complicated as the bi-level optimisation seen in meta-learning, as shown in Equation&nbsp;2. Furthermore, multi-task learning aims to solve a number of specific tasks known during training (referred to as target tasks), while meta-learning targets generalisation to unseen tasks in the future.</p>
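<p>As a minimal sketch of the single-level objective in Equation&nbsp;7, the NumPy snippet below jointly updates a shared feature extractor and the task-specific heads with plain gradient descent. The <code>tanh</code> extractor, the linear heads, the squared loss, the random data, the sizes and the step size are all assumptions made for illustration; none of them is prescribed by the formulation above.</p>
<pre class="sourceCode python code-with-copy"><code class="sourceCode python">import numpy as np

rng = np.random.default_rng(0)
num_tasks, num_samples, dim_in, dim_feat = 3, 32, 5, 4   # T + 1 = 3 (assumed sizes)

x = rng.normal(size=(num_samples, dim_in))               # input x shared by all tasks
y = rng.normal(size=(num_tasks, num_samples))            # one label vector y_i per task

w0 = 0.1 * rng.normal(size=(dim_in, dim_feat))           # shared extractor g(.; w0)
w = 0.1 * rng.normal(size=(num_tasks, dim_feat))         # task-specific heads w_i


def mtl_loss_and_grads(w0, w):
    """Average squared loss over all tasks (Equation 7) and its gradients."""
    z = np.tanh(x @ w0)                                  # g(x; w0), shape (samples, features)
    residual = z @ w.T - y.T                             # h_i(g(x)) - y_i, shape (samples, tasks)
    loss = 0.5 * np.mean(residual ** 2)
    scale = 1.0 / residual.size
    grad_w = scale * residual.T @ z                      # gradient w.r.t. every head w_i
    grad_z = scale * residual @ w                        # back-propagate into the features
    grad_w0 = x.T @ (grad_z * (1.0 - z ** 2))            # through tanh into w0
    return loss, grad_w0, grad_w


losses = []
for _ in range(200):
    loss, grad_w0, grad_w = mtl_loss_and_grads(w0, w)
    losses.append(loss)
    w0 -= 0.1 * grad_w0                                  # one joint, single-level update
    w -= 0.1 * grad_w
</code></pre>
<p>All parameters are updated in the same loop from the same loss value, so there is no inner adaptation step nested inside an outer evaluation step, which is what a bi-level formulation would require.</p>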
</section>
<section id="continual-learning" class="level3" data-number="2.4">
<h3 data-number="2.4" class="anchored" data-anchor-id="continual-learning"><span class="header-section-number">2.4</span> Continual learning</h3>
<p>Continual or <em>life-long learning</em> refers to a situation where a learning agent has access to a continuous stream of tasks that become available over time, and the number of tasks to be learnt is not pre-defined <span class="citation" data-cites="chen2018lifelong parisi2019continual">(Chen and Liu 2018; Parisi et al. 2019)</span>. The aim is to use the knowledge extracted from tasks that are each observed only once to accelerate the learning of new tasks without catastrophically forgetting old ones <span class="citation" data-cites="french1999catastrophic">(French 1999)</span>. In this sense, continual learning is very similar to meta-learning. However, continual learning mostly focuses on <b>systematic</b> design to acquire new knowledge in a way that does not interfere with existing knowledge, while meta-learning is more about <b>algorithmic</b> design to learn new knowledge more efficiently. Thus, we cannot distinguish them mathematically as done in the sub-sections on fine-tuning, domain adaptation and generalisation, and multi-task learning. Nevertheless, continual learning criteria, especially the avoidance of catastrophic forgetting, can be encoded into the meta-learning objective to further improve continual learning performance <span class="citation" data-cites="al2018continuous nagabandi2019learning">(Al-Shedivat et al. 2018; Nagabandi et al. 2019)</span>.</p>
</section>
</section>
<section id="summary" class="level2" data-number="3">
<h2 data-number="3" class="anchored" data-anchor-id="summary"><span class="header-section-number">3</span> Summary</h2>
<p>In general, meta-learning is an extension of hyper-parameter optimisation to a multi-task setting. The objective function of meta-learning is, therefore, a bi-level optimisation, where the lower level adapts the meta-parameter to a task, while the upper level evaluates how well the meta-parameter performs across <img src="https://latex.codecogs.com/png.latex?T"> tasks. Given this mathematical formulation, we can easily distinguish meta-learning from some common transfer learning approaches, such as fine-tuning, multi-task learning, domain adaptation and continual learning.</p>
<p>I hope this post gives another perspective on meta-learning. I’ll see you in the next post about probabilistic methods in meta-learning.</p>
</section>
<section id="references" class="level2" data-number="4">
<h2 data-number="4" class="anchored" data-anchor-id="references"><span class="header-section-number">4</span> References</h2>
<div id="refs" class="references csl-bib-body hanging-indent">
<div id="ref-al2018continuous" class="csl-entry">
Al-Shedivat, Maruan, Trapit Bansal, Yuri Burda, Ilya Sutskever, Igor Mordatch, and Pieter Abbeel. 2018. <span>“Continuous Adaptation via Meta-Learning in Nonstationary and Competitive Environments.”</span> <em>International Conference on Learning Representations</em>.
</div>
<div id="ref-baxter2000model" class="csl-entry">
Baxter, Jonathan. 2000. <span>“A Model of Inductive Bias Learning.”</span> <em>Journal of Artificial Intelligence Research</em> 12: 149–98.
</div>
<div id="ref-ben2007analysis" class="csl-entry">
<span class="nocase">Ben-David, Shai, John Blitzer, Koby Crammer, Fernando Pereira, et al.</span> 2007. <span>“Analysis of Representations for Domain Adaptation.”</span> <em>Advances in Neural Information Processing Systems</em> 19: 137.
</div>
<div id="ref-caruana1997multitask" class="csl-entry">
Caruana, Rich. 1997. <span>“Multitask Learning.”</span> <em>Machine Learning</em> 28 (1): 41–75.
</div>
<div id="ref-chen2018lifelong" class="csl-entry">
Chen, Zhiyuan, and Bing Liu. 2018. <span>“Lifelong Machine Learning.”</span> <em>Synthesis Lectures on Artificial Intelligence and Machine Learning</em> 12 (3): 1–207.
</div>
<div id="ref-daume2006domain" class="csl-entry">
Daume III, Hal, and Daniel Marcu. 2006. <span>“Domain Adaptation for Statistical Classifiers.”</span> <em>Journal of Artificial Intelligence Research</em> 26: 101–26.
</div>
<div id="ref-domke2012generic" class="csl-entry">
Domke, Justin. 2012. <span>“Generic Methods for Optimization-Based Modeling.”</span> <em>Artificial Intelligence and Statistics</em>, 318–26.
</div>
<div id="ref-finn2017model" class="csl-entry">
Finn, Chelsea, Pieter Abbeel, and Sergey Levine. 2017. <span>“Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.”</span> <em>International Conference on Machine Learning</em>, 1126–35.
</div>
<div id="ref-french1999catastrophic" class="csl-entry">
French, Robert M. 1999. <span>“Catastrophic Forgetting in Connectionist Networks.”</span> <em>Trends in Cognitive Sciences</em> 3 (4): 128–35.
</div>
<div id="ref-heckman1979sample" class="csl-entry">
Heckman, James J. 1979. <span>“Sample Selection Bias as a Specification Error.”</span> <em>Econometrica: Journal of the Econometric Society</em>, 153–61.
</div>
<div id="ref-hospedales2021meta" class="csl-entry">
Hospedales, Timothy M, Antreas Antoniou, Paul Micaelli, and Amos J Storkey. 2021. <span>“Meta-Learning in Neural Networks: A Survey.”</span> <em>IEEE Transactions on Pattern Analysis and Machine Intelligence</em>.
</div>
<div id="ref-japkowicz2002class" class="csl-entry">
Japkowicz, Nathalie, and Shaju Stephen. 2002. <span>“The Class Imbalance Problem: A Systematic Study.”</span> <em>Intelligent Data Analysis</em> 6 (5): 429–49.
</div>
<div id="ref-li2018learning" class="csl-entry">
Li, Da, Yongxin Yang, Yi-Zhe Song, and Timothy M Hospedales. 2018. <span>“Learning to Generalize: Meta-Learning for Domain Generalization.”</span> <em>Thirty-Second AAAI Conference on Artificial Intelligence</em>.
</div>
<div id="ref-li2019feature" class="csl-entry">
Li, Yiying, Yongxin Yang, Wei Zhou, and Timothy Hospedales. 2019. <span>“Feature-Critic Networks for Heterogeneous Domain Generalization.”</span> <em>International Conference on Machine Learning</em>, 3915–24.
</div>
<div id="ref-li2017meta" class="csl-entry">
Li, Zhenguo, Fengwei Zhou, Fei Chen, and Hang Li. 2017. <span>“Meta-SGD: Learning to Learn Quickly for Few-Shot Learning.”</span> <em>arXiv Preprint arXiv:1707.09835</em>.
</div>
<div id="ref-lorraine2020optimizing" class="csl-entry">
Lorraine, Jonathan, Paul Vicol, and David Duvenaud. 2020. <span>“Optimizing Millions of Hyperparameters by Implicit Differentiation.”</span> <em>International Conference on Artificial Intelligence and Statistics</em>, 1540–52.
</div>
<div id="ref-nagabandi2019learning" class="csl-entry">
Nagabandi, Anusha, Ignasi Clavera, Simin Liu, et al. 2019. <span>“Learning to Adapt in Dynamic, Real-World Environments Through Meta-Reinforcement Learning.”</span> <em>International Conference on Learning Representations</em>.
</div>
<div id="ref-naik1992meta" class="csl-entry">
Naik, Devang K, and RJ Mammone. 1992. <span>“Meta-Neural Networks That Learn by Learning.”</span> <em>International Joint Conference on Neural Networks</em> 1: 437–42.
</div>
<div id="ref-nichol2018on" class="csl-entry">
Nichol, Alex, Joshua Achiam, and John Schulman. 2018. <span>“On First-Order Meta-Learning Algorithms.”</span> <em>CoRR</em> abs/1803.02999. <a href="http://arxiv.org/abs/1803.02999">http://arxiv.org/abs/1803.02999</a>.
</div>
<div id="ref-parisi2019continual" class="csl-entry">
Parisi, German I, Ronald Kemker, Jose L Part, Christopher Kanan, and Stefan Wermter. 2019. <span>“Continual Lifelong Learning with Neural Networks: A Review.”</span> <em>Neural Networks</em> 113: 54–71.
</div>
<div id="ref-pearlmutter94fastexact" class="csl-entry">
Pearlmutter, Barak A. 1994. <span>“Fast Exact Multiplication by the <span>Hessian</span>.”</span> <em>Neural Computation</em> 6: 147–60.
</div>
<div id="ref-pratt1991direct" class="csl-entry">
Pratt, Lorien Y, Jack Mostow, Candace A Kamm, and Ace A Kamm. 1991. <span>“Direct Transfer of Learned Information Among Neural Networks.”</span> <em>AAAI</em> 91: 584–89.
</div>
<div id="ref-rajeswaran2019meta" class="csl-entry">
Rajeswaran, Aravind, Chelsea Finn, Sham Kakade, and Sergey Levine. 2019. <em>Meta-Learning with Implicit Gradients</em>.
</div>
<div id="ref-schmidhuber1987evolutionary" class="csl-entry">
Schmidhuber, Jürgen. 1987. <span>“Evolutionary Principles in Self-Referential Learning (on Learning How to Learn: The Meta-Meta-... Hook).”</span> Diploma thesis, Technische Universität München.
</div>
<div id="ref-shimodaira2000improving" class="csl-entry">
Shimodaira, Hidetoshi. 2000. <span>“Improving Predictive Inference Under Covariate Shift by Weighting the Log-Likelihood Function.”</span> <em>Journal of Statistical Planning and Inference</em> 90 (2): 227–44.
</div>
<div id="ref-snell2017prototypical" class="csl-entry">
Snell, Jake, Kevin Swersky, and Richard Zemel. 2017. <span>“Prototypical Networks for Few-Shot Learning.”</span> <em>Advances in Neural Information Processing Systems</em>, 4077–87.
</div>
<div id="ref-vinyals2016matching" class="csl-entry">
<span class="nocase">Vinyals, Oriol, Charles Blundell, Timothy Lillicrap, Daan Wierstra, et al.</span> 2016. <span>“Matching Networks for One Shot Learning.”</span> <em>Advances in Neural Information Processing Systems</em> 29: 3630–38.
</div>
<div id="ref-yosinski2014transferable" class="csl-entry">
Yosinski, Jason, Jeff Clune, Yoshua Bengio, and Hod Lipson. 2014. <span>“How Transferable Are Features in Deep Neural Networks?”</span> <em>Advances in Neural Information Processing Systems</em>.
</div>
</div>



</section>

<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-reuse"><h2 class="anchored quarto-appendix-heading">Reuse</h2><div class="quarto-appendix-contents"><div><a rel="license" href="https://creativecommons.org/licenses/by/4.0/">CC BY 4.0</a></div></div></section><section class="quarto-appendix-contents" id="quarto-citation"><h2 class="anchored quarto-appendix-heading">Citation</h2><div><div class="quarto-appendix-secondary-label">BibTeX citation:</div><pre class="sourceCode code-with-copy quarto-appendix-bibtex"><code class="sourceCode bibtex">@online{nguyen2021,
  author = {Nguyen, Cuong},
  title = {From Hyper-Parameter Optimisation to Meta-Learning},
  date = {2021-11-22},
  url = {https://cnguyen10.github.io/posts/meta-learning/},
  langid = {en}
}
</code></pre><div class="quarto-appendix-secondary-label">For attribution, please cite this work as:</div><div id="ref-nguyen2021" class="csl-entry quarto-appendix-citeas">
Nguyen, Cuong. 2021. <span>“From Hyper-Parameter Optimisation to
Meta-Learning.”</span> November 22. <a href="https://cnguyen10.github.io/posts/meta-learning/">https://cnguyen10.github.io/posts/meta-learning/</a>.
</div></div></section></div> ]]></description>
  <category>Meta-Learning</category>
  <category>Transfer Learning</category>
  <category>Optimisation</category>
  <guid>https://cnguyen10.github.io/posts/meta-learning/</guid>
  <pubDate>Mon, 22 Nov 2021 00:00:00 GMT</pubDate>
</item>
<item>
  <title>Outer product approximation of Hessian matrix</title>
  <dc:creator>Cuong Nguyen</dc:creator>
  <link>https://cnguyen10.github.io/posts/Gauss-Newton-matrix/</link>
  <description><![CDATA[ 




<p>The Hessian matrix is heavily studied in the optimisation community, where second-order derivatives are used to optimise a function of interest (as in Newton’s method). In machine learning, especially Bayesian inference, the Hessian matrix appears in applications such as Laplace’s method, which approximates a distribution by a Gaussian distribution. Although the Hessian matrix provides additional information that improves the convergence rate in optimisation or reduces a complicated distribution to a Gaussian distribution, calculating it is computationally expensive. In neural networks, where the number of model parameters is very large, computing and storing the full Hessian matrix is often intractable due to limited computation and memory.</p>
<p>Many efficient approximations of the Hessian matrix have been developed to either reduce the running time complexity or decompose the Hessian matrix to reduce the amount of memory storage. Hessian-free approaches, which utilise Hessian-vector products, have also attracted much research interest. This post presents an approximation of the Hessian matrix using the outer product, also known as the Gauss-Newton matrix. This approximation represents the approximated Hessian matrix by a set of matrices whose sizes are reasonable to store in GPU memory. The trade-off is that the running time complexity to obtain the full Hessian matrix is still quadratic in the number of model parameters.</p>
<section id="notations" class="level2" data-number="1">
<h2 data-number="1" class="anchored" data-anchor-id="notations"><span class="header-section-number">1</span> Notations</h2>
<p>Before going into details, let’s define some notations used:</p>
<ul>
<li><img src="https://latex.codecogs.com/png.latex?%5C%7Bx_%7Bi%7D,%20t_%7Bi%7D%5C%7D_%7Bi%20=%201%7D%5E%7BN%7D"> are the inputs and labels, where <img src="https://latex.codecogs.com/png.latex?i"> indexes the <img src="https://latex.codecogs.com/png.latex?i">-th data point,</li>
<li><img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bw%7D%20%5Cin%20%5Cmathbb%7BR%7D%5E%7BW%7D"> is the parameter of the model of interest, or the weight of a neural network,</li>
<li><img src="https://latex.codecogs.com/png.latex?%5Cell(.)%20%5Cin%20%5Cmathbb%7BR%7D"> is the loss function, e.g.&nbsp;MSE or cross-entropy,</li>
<li><img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cin%20%5Cmathbb%7BR%7D%5E%7BC%7D"> is the pre-nonlinearity output of the neural network at the final layer, which has <img src="https://latex.codecogs.com/png.latex?C"> output units,</li>
<li><img src="https://latex.codecogs.com/png.latex?%5Csigma%5Cleft%5B%20%5Cmathbf%7Bf%7D%5Cleft(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D%5Cright)%20%5Cright%5D%20%5Cin%20%5Cmathbb%7BR%7D%5E%7BC%7D"> is the activation output at the final layer. For example, in regression, <img src="https://latex.codecogs.com/png.latex?%5Csigma(z)%20=%20z"> is the identity function; in logistic regression, <img src="https://latex.codecogs.com/png.latex?%5Csigma(.)"> is the sigmoid function; and in multi-class classification, <img src="https://latex.codecogs.com/png.latex?%5Csigma(.)"> is the softmax function.</li>
</ul>
<p>The loss function of interest is defined as the sum of the losses over all data points: <img src="https://latex.codecogs.com/png.latex?%0AL%20=%20%5Csum_%7Bi%20=%201%7D%5E%7BN%7D%20%5Cell%5Cleft(%20%5Csigma(%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)),%20t_%7Bi%7D%5Cright).%0A"> Note that in the following, we will omit the label <img src="https://latex.codecogs.com/png.latex?t_%7Bi%7D"> from the loss <img src="https://latex.codecogs.com/png.latex?%5Cell(.)"> to keep the notation uncluttered.</p>
</section>
<section id="derivation-of-the-approximated-hessian-matrix" class="level2" data-number="2">
<h2 data-number="2" class="anchored" data-anchor-id="derivation-of-the-approximated-hessian-matrix"><span class="header-section-number">2</span> Derivation of the approximated Hessian matrix</h2>
<p>An element of the Hessian matrix can then be written as: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%5Cmathbf%7BH%7D_%7Bjk%7D%20&amp;%20=%20%5Cfrac%7B%5Cpartial%7D%7B%5Cpartial%5Cmathbf%7Bw%7D_%7Bk%7D%7D%20%5Cleft(%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bj%7D%7D%20%5Cright)%20=%20%5Cfrac%7B%5Cpartial%7D%7B%5Cpartial%5Cmathbf%7Bw%7D_%7Bk%7D%7D%20%5Cleft(%20%5Csum_%7Bi=1%7D%5E%7BN%7D%20%5Cfrac%7B%5Cpartial%20%5Cell%20%5Cleft%5B%20%5Csigma%20%5Cleft(%20%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%5Cright%5D%7D%7B%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bj%7D%7D%20%5Cright)%20%5C%5C%0A&amp;%20=%20%5Cfrac%7B%5Cpartial%7D%7B%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bk%7D%7D%20%5Cleft(%20%5Csum_%7Bi=1%7D%5E%7BN%7D%20%5Csum_%7Bc=1%7D%5E%7BC%7D%20%5Cfrac%7B%5Cpartial%5Cell%20%5Cleft%5B%20%5Csigma%20%5Cleft(%20%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%5Cright%5D%7D%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bc%7D%20(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%20%5Cfrac%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bc%7D%20(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%7B%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bj%7D%7D%20%5Cright)%20%5Cquad%20%5Ctext%7B%5Ctextcolor%7BForestGreen%7D%7B(chain%20rule)%7D%7D%5C%5C%0A&amp;%20=%20%5Csum_%7Bi=1%7D%5E%7BN%7D%20%5Csum_%7Bc=1%7D%5E%7BC%7D%20%5Cfrac%7B%5Cpartial%7D%7B%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bk%7D%7D%20%5Cleft(%20%5Cfrac%7B%5Cpartial%20%5Cell%20%5Cleft%5B%20%5Csigma%20%5Cleft(%20%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%5Cright%5D%7D%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bc%7D%20(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%20%5Cfrac%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bc%7D%20(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%7B%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bj%7D%7D%20%5Cright).%0A%5Cend%7Baligned%7D%0A"></p>
<p>Applying the product rule, and then the chain rule to the first factor, gives: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%5Cmathbf%7BH%7D_%7Bjk%7D%20&amp;%20=%20%5Csum_%7Bi=1%7D%5E%7BN%7D%20%5Csum_%7Bc=1%7D%5E%7BC%7D%20%5Cleft%5B%20%5Csum_%7Bl=1%7D%5E%7BC%7D%20%5Cleft(%20%5Cfrac%7B%5Cpartial%5E%7B2%7D%20%5Cell%20%5Cleft%5B%20%5Csigma%20%5Cleft(%20%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%5Cright%5D%7D%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bc%7D%20(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5C,%20%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bl%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%20%5Cfrac%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bl%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%7B%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bk%7D%7D%20%5Cright)%20%5Cfrac%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bc%7D%20(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%7B%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bj%7D%7D%20%5Cright%5D%20%5C%5C%0A&amp;%20%5Cqquad%20%5Cqquad%20%5Cquad%20+%20%5Cfrac%7B%5Cpartial%20%5Cell%20%5Cleft%5B%20%5Csigma%20%5Cleft(%20%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%5Cright%5D%7D%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bc%7D%20(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%20%5Cfrac%7B%5Cpartial%5E%7B2%7D%20%5Cmathbf%7Bf%7D_%7Bc%7D%20(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%7B%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bj%7D%20%5C,%20%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bk%7D%7D.%0A%5Cend%7Baligned%7D%0A"></p>
<p>Rearranging gives: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%5Cmathbf%7BH%7D_%7Bjk%7D%20&amp;%20=%20%5Csum_%7Bi=1%7D%5E%7BN%7D%20%5Csum_%7Bc=1%7D%5E%7BC%7D%20%5Cfrac%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bc%7D%20(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%7B%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bj%7D%7D%20%5Csum_%7Bl=1%7D%5E%7BC%7D%20%5Cfrac%7B%5Cpartial%5E%7B2%7D%20%5Cell%20%5Cleft%5B%20%5Csigma%20%5Cleft(%20%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%5Cright%5D%7D%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bc%7D%20(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5C,%20%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bl%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%20%5Cfrac%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bl%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%7B%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bk%7D%7D%20%5C%5C%0A&amp;%20%5Cquad%20+%20%5Csum_%7Bi=1%7D%5E%7BN%7D%20%5Csum_%7Bc=1%7D%5E%7BC%7D%20%5Cunderbrace%7B%5Cfrac%7B%5Cpartial%20%5Cell%20%5Cleft%5B%20%5Csigma%20%5Cleft(%20%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%5Cright%5D%7D%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bc%7D%20(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%7D_%7B%5Capprox%200%7D%20%5Cfrac%7B%5Cpartial%5E%7B2%7D%20%5Cmathbf%7Bf%7D_%7Bc%7D%20(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%7B%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bj%7D%20%5C,%20%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bk%7D%7D.%0A%5Cend%7Baligned%7D%0A"></p>
<p>Near the optimum, the scalar <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bf%7D_%7Bc%7D"> would be very close to its target <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bt%7D_%7Bic%7D">. Hence, the derivative of the loss w.r.t. <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bf%7D_%7Bc%7D"> is very small, and we can approximate the Hessian as: <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathbf%7BH%7D_%7Bjk%7D%20%5Capprox%20%5Csum_%7Bi=1%7D%5E%7BN%7D%20%5Csum_%7Bc=1%7D%5E%7BC%7D%20%5Cfrac%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bc%7D%20(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%7B%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bj%7D%7D%20%5Csum_%7Bl=1%7D%5E%7BC%7D%20%5Cfrac%7B%5Cpartial%5E%7B2%7D%20%5Cell%20%5Cleft%5B%20%5Csigma%20%5Cleft(%20%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%5Cright%5D%7D%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bc%7D%20(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5C,%20%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bl%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%20%5Cfrac%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bl%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%7B%5Cpartial%20%5Cmathbf%7Bw%7D_%7Bk%7D%7D.%0A"></p>
<p>Rewriting this with matrix notation yields a much simpler formulation: <img src="https://latex.codecogs.com/png.latex?%0A%5Cboxed%7B%0A%5Cmathbf%7BH%7D%20%5Capprox%20%5Csum_%7Bi=1%7D%5E%7BN%7D%20%5Cmathbf%7BJ%7D_%7Bfi%7D%5E%7B%5Ctop%7D%20%5Cmathbf%7BH%7D_%7B%5Csigma%20i%7D%20%5Cmathbf%7BJ%7D_%7Bfi%7D,%0A%7D%0A"> where: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%5Cmathbf%7BJ%7D_%7Bfi%7D%20&amp;%20=%20%5Cnabla_%7B%5Cmathbf%7Bw%7D%7D%20%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cin%20%5Cmathbb%7BR%7D%5E%7BC%20%5Ctimes%20W%7D%20%5Cquad%20%5Ctext%7B%5Ctextcolor%7BForestGreen%7D%7B(Jacobian%20matrix%20of%20%5Ctextbf%7Bf%7D%20w.r.t.%20%5Ctextbf%7Bw%7D)%7D%7D%5C%5C%0A%5Cmathbf%7BH%7D_%7B%5Csigma%20i%7D%20&amp;%20=%20%5Cnabla_%7B%5Cmathbf%7Bf%7D%7D%5E%7B2%7D%20%5Cell%5Cleft%5B%20%5Csigma%20%5Cleft(%20%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%20%5Cright%5D%20%5Cin%20%5Cmathbb%7BR%7D%5E%7BC%20%5Ctimes%20C%7D%20%5Cquad%20%5Ctext%7B%5Ctextcolor%7BForestGreen%7D%7B(Hessian%20of%20loss%20w.r.t.%20%5Ctextbf%7Bf%7D)%7D%7D.%0A%5Cend%7Baligned%7D%0A"></p>
<p>Note that the Hessian matrix <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7BH%7D_%7B%5Csigma%7D"> can be derived analytically.</p>
<div class="proof remark">
<p><span class="proof-title"><em>Remark</em>. </span>Instead of storing the Hessian matrix <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7BH%7D"> of size <img src="https://latex.codecogs.com/png.latex?%7BW%20%5Ctimes%20W%7D">, which needs a large amount of memory, we can store the two sets of matrices <img src="https://latex.codecogs.com/png.latex?%5C%7B%5Cmathbf%7BJ%7D_%7Bfi%7D,%20%5Cmathbf%7BH%7D_%7B%5Csigma%20i%7D%5C%7D_%7Bi=1%7D%5E%7BN%7D">. This reduces the amount of memory required as long as <img src="https://latex.codecogs.com/png.latex?N%20C%20%5Cll%20W">. Of course, the trade-off is the extra computation needed to perform the multiplications whenever the Hessian matrix <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7BH%7D"> is required.</p>
</div>
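<p>The remark above can be illustrated with a small NumPy sketch that keeps only the per-example factors and multiplies them on demand. The arrays <code>J</code> and <code>H_sigma</code> are random stand-ins for the Jacobians and loss Hessians (assumptions for illustration, not outputs of a real network), and the function names are hypothetical.</p>
<pre class="sourceCode python code-with-copy"><code class="sourceCode python">import numpy as np

rng = np.random.default_rng(0)
N, C, W = 8, 3, 20                      # data points, outputs, parameters (assumed sizes)

# Stand-ins for the stored factors: J[i] plays J_fi (C x W), H_sigma[i] plays H_sigma_i (C x C).
J = rng.normal(size=(N, C, W))
A = rng.normal(size=(N, C, C))
H_sigma = A @ np.swapaxes(A, 1, 2)      # random symmetric positive semi-definite blocks


def dense_gauss_newton(J, H_sigma):
    """Form the full W x W matrix sum_i J_i^T H_sigma_i J_i."""
    return np.einsum("icj,icl,ilk-&#62;jk".replace("&#62;", chr(62)), J, H_sigma, J)


def gauss_newton_vector_product(J, H_sigma, v):
    """Compute H v from the stored factors without forming the W x W matrix."""
    Jv = J @ v                                           # J_i v, shape (N, C)
    HJv = (H_sigma @ Jv[:, :, None])[:, :, 0]            # H_sigma_i (J_i v), shape (N, C)
    return np.sum(np.swapaxes(J, 1, 2) @ HJv[:, :, None], axis=0)[:, 0]


v = rng.normal(size=W)
H = dense_gauss_newton(J, H_sigma)
Hv = gauss_newton_vector_product(J, H_sigma, v)
</code></pre>
<p>In this sketch the factors hold <code>N * C * (W + C)</code> numbers, while the dense matrix holds <code>W * W</code>; the matrix-free product is the usual way to benefit from the factored form when only products with a vector are needed.</p>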
<p>The following section will present how to calculate the matrix <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7BH%7D_%7B%5Csigma%7D"> for some commonly-used losses.</p>
</section>
<section id="derivation-for-mathbfh_sigma" class="level2" data-number="3">
<h2 data-number="3" class="anchored" data-anchor-id="derivation-for-mathbfh_sigma"><span class="header-section-number">3</span> Derivation for <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7BH%7D_%7B%5Csigma%7D"></h2>
<section id="mean-square-error-in-regression" class="level3" data-number="3.1">
<h3 data-number="3.1" class="anchored" data-anchor-id="mean-square-error-in-regression"><span class="header-section-number">3.1</span> Mean square error in regression</h3>
<p>In regression:</p>
<ul>
<li><img src="https://latex.codecogs.com/png.latex?C%20=%201"></li>
<li><img src="https://latex.codecogs.com/png.latex?%5Csigma(.)"> is the identity function</li>
<li><img src="https://latex.codecogs.com/png.latex?%5Cell(f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D))%20=%20%5Cfrac%7B1%7D%7B2%7D%20%5Cleft(%20f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20-%20t_%7Bi%7D%20%5Cright)%5E%7B2%7D">.</li>
</ul>
<p>Hence, <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7BH%7D_%7B%5Csigma%7D%20=%20%5Cmathbf%7BI%7D_%7B1%7D">, resulting in: <img src="https://latex.codecogs.com/png.latex?%0A%5Cboxed%7B%0A%20%20%20%20%5Cmathbf%7BH%7D%20%5Capprox%20%5Csum_%7Bi=1%7D%5E%7BN%7D%20%5Cmathbf%7BJ%7D_%7Bfi%7D%5E%7B%5Ctop%7D%20%5Cmathbf%7BJ%7D_%7Bfi%7D,%0A%7D%0A"> which agrees with the results in <span class="citation" data-cites="bishop2006pattern">(Bishop and Nasrabadi 2006 - Eq.(5.84))</span>.</p>
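<p>To see the outer-product formula at work, the following NumPy sketch compares it with a finite-difference Hessian of the squared-error loss. The two-parameter model <code>w[0] * tanh(w[1] * x)</code>, the noiseless targets and the helper names are assumptions chosen for illustration; the targets are generated at <code>w_star</code> so that all residuals vanish there, which is the regime in which the neglected term is exactly zero.</p>
<pre class="sourceCode python code-with-copy"><code class="sourceCode python">import numpy as np


def f(x, w):
    """Toy scalar model f(x, w) = w[0] * tanh(w[1] * x) (an assumed example)."""
    return w[0] * np.tanh(w[1] * x)


def jacobian(x, w):
    """Row i is J_fi = df(x_i, w)/dw, so the result has shape (N, 2)."""
    h = np.tanh(w[1] * x)
    return np.stack([h, w[0] * (1.0 - h ** 2) * x], axis=1)


def loss(w, x, t):
    return 0.5 * np.sum((f(x, w) - t) ** 2)


def finite_difference_hessian(fun, w, eps=1e-4):
    """Central-difference Hessian of a scalar function of w."""
    n = w.size
    H = np.zeros((n, n))
    for j in range(n):
        for k in range(n):
            ej, ek = eps * np.eye(n)[j], eps * np.eye(n)[k]
            H[j, k] = (fun(w + ej + ek) - fun(w + ej - ek)
                       - fun(w - ej + ek) + fun(w - ej - ek)) / (4.0 * eps ** 2)
    return H


rng = np.random.default_rng(0)
x = rng.normal(size=50)
w_star = np.array([1.5, -0.7])
t = f(x, w_star)                          # noiseless targets: residuals vanish at w_star

J = jacobian(x, w_star)
H_gn = J.T @ J                            # sum_i J_fi^T J_fi
H_fd = finite_difference_hessian(lambda w: loss(w, x, t), w_star)
</code></pre>
<p>Away from <code>w_star</code>, or with noisy targets, the two matrices would differ by the second-derivative term that the approximation drops.</p>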
</section>
<section id="logistic-regression" class="level3" data-number="3.2">
<h3 data-number="3.2" class="anchored" data-anchor-id="logistic-regression"><span class="header-section-number">3.2</span> Logistic regression</h3>
<p>In this case:</p>
<ul>
<li><img src="https://latex.codecogs.com/png.latex?C%20=%201"></li>
<li><img src="https://latex.codecogs.com/png.latex?%5Csigma(.)"> is the sigmoid function</li>
<li><img src="https://latex.codecogs.com/png.latex?%5Cell(%5Csigma(f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)))%20=%20-%20t_%7Bi%7D%20%5Cln%20%5Csigma%20%5Cleft(%20f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%20-%20(1%20-%20t_%7Bi%7D)%20%5Cln%20%5Cleft(%201%20-%20%5Csigma%20%5Cleft(%20f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%20%5Cright)">.</li>
</ul>
<p>The first derivative is expressed as: <img src="https://latex.codecogs.com/png.latex?%0A%5Cfrac%7B%5Cpartial%20%5Cell(%5Csigma(f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)))%7D%7B%5Cpartial%20f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%20=%20-%20t_%7Bi%7D%20%5Cleft(%201%20-%20%5Csigma%20%5Cleft(%20f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%20%5Cright)%20+%20(1%20-%20t_%7Bi%7D)%20%5Csigma%20%5Cleft(%20f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%20=%20%5Csigma%20%5Cleft(%20f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%20-%20t_%7Bi%7D.%0A"></p>
<p>The second derivative is therefore: <img src="https://latex.codecogs.com/png.latex?%0A%5Cfrac%7B%5Cpartial%5E%7B2%7D%20%5Cell(%5Csigma(f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)))%7D%7B%5Cpartial%20f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%5E%7B2%7D%7D%20=%20%5Csigma%20%5Cleft(%20f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%20%5Cleft%5B%201%20-%20%5Csigma%20%5Cleft(%20f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%20%5Cright%5D.%0A"></p>
<p>Hence: <img src="https://latex.codecogs.com/png.latex?%0A%5Cboxed%7B%0A%20%20%20%20%5Cmathbf%7BH%7D%20%5Capprox%20%5Csum_%7Bi=1%7D%5E%7BN%7D%20%5Csigma%20%5Cleft(%20f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%20%5Cleft%5B%201%20-%20%5Csigma%20%5Cleft(%20f(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%20%5Cright)%20%5Cright%5D%20%5Cmathbf%7BJ%7D_%7Bfi%7D%5E%7B%5Ctop%7D%20%5Cmathbf%7BJ%7D_%7Bfi%7D,%0A%7D%0A"> which agrees with the result derived in the literature <span class="citation" data-cites="bishop2006pattern">(Bishop and Nasrabadi 2006 - Eq. (5.85))</span>.</p>
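<p>The two derivatives above are easy to check numerically. The sketch below compares them with finite differences at one arbitrarily chosen logit and label, and then assembles the weighted outer-product sum for an assumed linear model <code>f(x_i, w) = x_i . w</code>, for which each Jacobian is simply the input row. The data, sizes and names are illustrative assumptions.</p>
<pre class="sourceCode python code-with-copy"><code class="sourceCode python">import numpy as np


def sigmoid(f):
    return 1.0 / (1.0 + np.exp(-f))


def bce(f, t):
    """Binary cross-entropy written as a function of the logit f."""
    s = sigmoid(f)
    return -t * np.log(s) - (1.0 - t) * np.log(1.0 - s)


# Finite-difference check of the derivatives w.r.t. the logit.
f0, t0, eps = 0.3, 1.0, 1e-4
first_fd = (bce(f0 + eps, t0) - bce(f0 - eps, t0)) / (2.0 * eps)
second_fd = (bce(f0 + eps, t0) - 2.0 * bce(f0, t0) + bce(f0 - eps, t0)) / eps ** 2

first_closed = sigmoid(f0) - t0                      # sigma(f) - t
second_closed = sigmoid(f0) * (1.0 - sigmoid(f0))    # sigma(f) [1 - sigma(f)]

# Assumed linear model f(x_i, w) = x_i . w, so that J_fi is the row x_i.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
w = rng.normal(size=3)
s = sigmoid(X @ w)
H = (X * (s * (1.0 - s))[:, None]).T @ X             # sum_i s_i (1 - s_i) x_i x_i^T
</code></pre>
<p>Because every weight <code>s_i (1 - s_i)</code> is non-negative, the resulting matrix is symmetric and positive semi-definite.</p>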
</section>
<section id="cross-entropy-loss-in-classification" class="level3" data-number="3.3">
<h3 data-number="3.3" class="anchored" data-anchor-id="cross-entropy-loss-in-classification"><span class="header-section-number">3.3</span> Cross entropy loss in classification</h3>
<p>In this case:</p>
<ul>
<li><img src="https://latex.codecogs.com/png.latex?%5Csigma(%5Cmathbf%7Bf%7D)"> is the softmax function,</li>
<li><img src="https://latex.codecogs.com/png.latex?%5Cell(%5Csigma(%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)))%20=%20-%5Csum_%7Bc=1%7D%5E%7BC%7D%20%5Cmathbf%7Bt%7D_%7Bic%7D%20%5Cln%20%5Csigma_%7Bc%7D(%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D))">.</li>
</ul>
<p>According to the definition of the softmax function: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Csigma_%7Bc%7D%20%5Cleft(%20%5Cmathbf%7Bf%7D%20%5Cright)%20=%20%5Cfrac%7B%5Cexp(%5Cmathbf%7Bf%7D_%7Bc%7D)%7D%7B%5Csum_%7Bk=1%7D%5E%7BC%7D%20%5Cexp(%5Cmathbf%7Bf%7D_%7Bk%7D)%7D.%0A"></p>
<p>Hence, the derivatives can be written as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cfrac%7B%5Cpartial%20%5Csigma_%7Bc%7D(%5Cmathbf%7Bf%7D)%7D%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bc%7D%7D%20=%20%5Cfrac%7B%5Cexp(%5Cmathbf%7Bf%7D_%7Bc%7D)%20%5Csum_%7Bk=1%7D%5E%7BC%7D%20%5Cexp(%5Cmathbf%7Bf%7D_%7Bk%7D)%20-%20%5Cexp(2%20%5Cmathbf%7Bf%7D_%7Bc%7D)%7D%7B%5Cleft%5B%20%5Csum_%7Bk=1%7D%5E%7BC%7D%20%5Cexp(%5Cmathbf%7Bf%7D_%7Bk%7D)%20%5Cright%5D%5E%7B2%7D%7D%20=%20%5Csigma_%7Bc%7D(%5Cmathbf%7Bf%7D)%20%5Cleft%5B%201%20-%20%5Csigma_%7Bc%7D(%5Cmathbf%7Bf%7D)%20%5Cright%5D,%0A"> and <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cfrac%7B%5Cpartial%20%5Csigma_%7Bc%7D(%5Cmathbf%7Bf%7D)%7D%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bk%7D%7D%20=%20-%20%5Csigma_%7Bc%7D(%5Cmathbf%7Bf%7D)%20%5Csigma_%7Bk%7D(%5Cmathbf%7Bf%7D),%20%5Cforall%20k%20%5Cneq%20c.%0A"></p>
<p>An element of the Jacobian vector of the loss w.r.t. <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bf%7D"> can be written as: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%20%20%20%20%5Cfrac%7B%5Cpartial%20%5Cell(%5Csigma(%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)))%7D%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bc%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)%7D%20&amp;%20=%20-%20%5Csum_%7Bk=1%7D%5E%7BC%7D%20%5Cfrac%7B%5Cmathbf%7Bt%7D_%7Bik%7D%7D%7B%5Csigma_%7Bk%7D(%5Cmathbf%7Bf%7D)%7D%20%5Cfrac%7B%5Cpartial%20%5Csigma_%7Bk%7D(%5Cmathbf%7Bf%7D)%7D%7B%5Cpartial%20%5Cmathbf%7Bf%7D_%7Bc%7D%7D%20%5C%5C%0A%20%20%20%20&amp;%20=%20-%20%5Cmathbf%7Bt%7D_%7Bic%7D%20%5Cleft%5B%201%20-%20%5Csigma_%7Bc%7D(%5Cmathbf%7Bf%7D)%20%5Cright%5D%20+%20%5Csum_%7B%5Csubstack%7Bk=1%5C%5Ck%20%5Cneq%20c%7D%7D%5E%7BC%7D%20%5Cmathbf%7Bt%7D_%7Bik%7D%20%5Csigma_%7Bc%7D(%5Cmathbf%7Bf%7D)%20%5C%5C%0A%20%20%20%20&amp;%20=%20-%20%5Cmathbf%7Bt%7D_%7Bic%7D%20+%20%5Csigma_%7Bc%7D(%5Cmathbf%7Bf%7D)%20%5Cunderbrace%7B%5Csum_%7Bk=1%7D%5E%7BC%7D%20%5Cmathbf%7Bt%7D_%7Bik%7D%7D_%7B1%7D%5C%5C%0A%20%20%20%20&amp;%20=%20%5Csigma_%7Bc%7D(%5Cmathbf%7Bf%7D)%20-%20%5Cmathbf%7Bt%7D_%7Bic%7D.%0A%5Cend%7Baligned%7D%0A"></p>
<p>Hence, the Jacobian vector can be expressed as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cnabla_%7B%5Cmathbf%7Bf%7D%7D%20%5Cell(%5Csigma(%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)))%20=%20%5Csigma(%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D))%20-%20%5Cmathbf%7Bt%7D_%7Bi%7D.%0A"></p>
<p>The Hessian matrix is given as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cnabla_%7B%5Cmathbf%7Bf%7D%7D%5E%7B2%7D%20%5Cell(%5Csigma(%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)))%20=%20%5Cnabla_%7B%5Cmathbf%7Bf%7D%7D%20%5Csigma(%5Cmathbf%7Bf%7D(x_%7Bi%7D,%20%5Cmathbf%7Bw%7D)).%0A"></p>
<p>Or, in the explicit matrix form: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbf%7BH%7D_%7B%5Csigma%7D%20=%20%5Cbegin%7Bbmatrix%7D%0A%20%20%20%20%5Csigma_%7B1%7D(%5Cmathbf%7Bf%7D)%20%5Cleft%5B%201%20-%20%5Csigma_%7B1%7D(%5Cmathbf%7Bf%7D)%20%5Cright%5D%20&amp;%20-%20%5Csigma_%7B1%7D(%5Cmathbf%7Bf%7D)%20%5Csigma_%7B2%7D(%5Cmathbf%7Bf%7D)%20&amp;%20-%20%5Csigma_%7B1%7D(%5Cmathbf%7Bf%7D)%20%5Csigma_%7B3%7D(%5Cmathbf%7Bf%7D)%20&amp;%20%5Cldots%20&amp;%20-%20%5Csigma_%7B1%7D(%5Cmathbf%7Bf%7D)%20%5Csigma_%7BC%7D(%5Cmathbf%7Bf%7D)%5C%5C%0A%20%20%20%20-%20%5Csigma_%7B2%7D(%5Cmathbf%7Bf%7D)%20%5Csigma_%7B1%7D(%5Cmathbf%7Bf%7D)%20&amp;%20%5Csigma_%7B2%7D(%5Cmathbf%7Bf%7D)%20%5Cleft%5B%201%20-%20%5Csigma_%7B2%7D(%5Cmathbf%7Bf%7D)%20%5Cright%5D%20&amp;%20-%20%5Csigma_%7B2%7D(%5Cmathbf%7Bf%7D)%20%5Csigma_%7B3%7D(%5Cmathbf%7Bf%7D)%20&amp;%20%5Cldots%20&amp;%20-%20%5Csigma_%7B2%7D(%5Cmathbf%7Bf%7D)%20%5Csigma_%7BC%7D(%5Cmathbf%7Bf%7D)%5C%5C%0A%20%20%20%20%5Cvdots%20&amp;%20%5Cvdots%20&amp;%20%5Cddots%20&amp;%20%5Cvdots%20&amp;%20%5Cvdots%5C%5C%0A%20%20%20%20-%20%5Csigma_%7BC%7D(%5Cmathbf%7Bf%7D)%20%5Csigma_%7B1%7D(%5Cmathbf%7Bf%7D)%20&amp;%20-%20%5Csigma_%7BC%7D(%5Cmathbf%7Bf%7D)%20%5Csigma_%7B2%7D(%5Cmathbf%7Bf%7D)%20&amp;%20-%20%5Csigma_%7BC%7D(%5Cmathbf%7Bf%7D)%20%5Csigma_%7B3%7D(%5Cmathbf%7Bf%7D)%20&amp;%20%5Cldots%20&amp;%20%5Csigma_%7BC%7D(%5Cmathbf%7Bf%7D)%20%5Cleft%5B%201%20-%20%5Csigma_%7BC%7D(%5Cmathbf%7Bf%7D)%20%5Cright%5D%0A%20%20%20%20%5Cend%7Bbmatrix%7D.%0A"></p>
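<p>The matrix above is the diagonal matrix of the softmax outputs minus their outer product. A minimal sketch in plain Python (the logits, the one-hot target and the function names are hypothetical values chosen for illustration) builds this matrix and compares it with a finite-difference Jacobian of the gradient σ(f) − t:</p>

```python
import math


def softmax(f):
    m = max(f)  # subtract the maximum for numerical stability
    e = [math.exp(v - m) for v in f]
    z = sum(e)
    return [v / z for v in e]


def softmax_hessian(f):
    """H[c][k] = sigma_c (1 - sigma_c) if c == k, else -sigma_c sigma_k."""
    s = softmax(f)
    C = len(f)
    return [[s[c] * ((1.0 if c == k else 0.0) - s[k]) for k in range(C)]
            for c in range(C)]


def grad_loss(f, t):
    """Gradient of the cross-entropy loss w.r.t. the logits: sigma(f) - t."""
    return [s - ti for s, ti in zip(softmax(f), t)]


f = [0.5, -1.0, 2.0]   # hypothetical logits
t = [0.0, 1.0, 0.0]    # hypothetical one-hot target
H = softmax_hessian(f)

# Reference: central finite differences of the gradient sigma(f) - t.
eps = 1e-6
H_fd = [[0.0] * 3 for _ in range(3)]
for k in range(3):
    f_plus = list(f)
    f_plus[k] += eps
    f_minus = list(f)
    f_minus[k] -= eps
    g_plus, g_minus = grad_loss(f_plus, t), grad_loss(f_minus, t)
    for c in range(3):
        H_fd[c][k] = (g_plus[c] - g_minus[c]) / (2 * eps)

max_gap = max(abs(H[c][k] - H_fd[c][k]) for c in range(3) for k in range(3))
row_sums = [sum(row) for row in H]
print("largest entry-wise gap:", max_gap)
print("row sums:", row_sums)
```

<p>Each row sums to zero because the softmax outputs sum to one, so this Hessian is singular; it does not depend on the target t at all.</p>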
</section>
</section>
<section id="conclusion" class="level2" data-number="4">
<h2 data-number="4" class="anchored" data-anchor-id="conclusion"><span class="header-section-number">4</span> Conclusion</h2>
<p>In this post, we derived an approximation of the Hessian matrix. The Gauss-Newton matrix is a good approximation since it is positive semi-definite and can be stored more efficiently as a set of smaller matrices. Of course, we have not escaped the curse of dimensionality, since the running time complexity to obtain the Hessian matrix is still quadratic w.r.t. the number of model parameters. One final note is that the approximated Hessian matrix should be used with care, since the approximation is only assumed to hold near a minimum of the considered loss function.</p>
</section>
<section id="references" class="level2" data-number="5">
<h2 data-number="5" class="anchored" data-anchor-id="references"><span class="header-section-number">5</span> References</h2>
<div id="refs" class="references csl-bib-body hanging-indent">
<div id="ref-bishop2006pattern" class="csl-entry">
Bishop, Christopher M, and Nasser M Nasrabadi. 2006. <em>Pattern Recognition and Machine Learning</em>. Vol. 4. Springer.
</div>
</div>


<!-- -->

</section>

<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-reuse"><h2 class="anchored quarto-appendix-heading">Reuse</h2><div class="quarto-appendix-contents"><div><a rel="license" href="https://creativecommons.org/licenses/by/4.0/">CC BY 4.0</a></div></div></section><section class="quarto-appendix-contents" id="quarto-citation"><h2 class="anchored quarto-appendix-heading">Citation</h2><div><div class="quarto-appendix-secondary-label">BibTeX citation:</div><pre class="sourceCode code-with-copy quarto-appendix-bibtex"><code class="sourceCode bibtex">@online{nguyen2021,
  author = {Nguyen, Cuong},
  title = {Outer Product Approximation of {Hessian} Matrix},
  date = {2021-04-12},
  url = {https://cnguyen10.github.io/posts/Gauss-Newton-matrix/},
  langid = {en}
}
</code></pre><div class="quarto-appendix-secondary-label">For attribution, please cite this work as:</div><div id="ref-nguyen2021" class="csl-entry quarto-appendix-citeas">
Nguyen, Cuong. 2021. <span>“Outer Product Approximation of Hessian
Matrix.”</span> April 12. <a href="https://cnguyen10.github.io/posts/Gauss-Newton-matrix/">https://cnguyen10.github.io/posts/Gauss-Newton-matrix/</a>.
</div></div></section></div> ]]></description>
  <category>Optimisation</category>
  <category>Machine Learning</category>
  <category>Linear Algebra</category>
  <guid>https://cnguyen10.github.io/posts/Gauss-Newton-matrix/</guid>
  <pubDate>Mon, 12 Apr 2021 00:00:00 GMT</pubDate>
</item>
<item>
  <title>PAC-Bayes bounds for generalisation error</title>
  <dc:creator>Cuong Nguyen</dc:creator>
  <link>https://cnguyen10.github.io/posts/PAC-Bayes-bounds/</link>
  <description><![CDATA[ 




<p>Probably approximately correct (PAC) learning is a part of <em>statistical machine learning</em>, which is a fundamental course in most graduate programs in machine learning. Its main idea is to upper-bound the <em>true risk</em> (or generalisation error) by the <em>empirical risk</em> with a certain confidence level. In other words, a PAC bound is often written in the following form: <img src="https://latex.codecogs.com/png.latex?%0A%5CPr%20(%5Ctext%7Btrue%20risk%7D%20%5Cle%20%5Ctext%7Bempirical%20risk%7D%20+%20r(m,%20%5Cdelta))%20%5Cge%201%20-%20%5Cdelta%0A"> where <img src="https://latex.codecogs.com/png.latex?%5CPr(A)"> is the probability of event <img src="https://latex.codecogs.com/png.latex?A">, <img src="https://latex.codecogs.com/png.latex?%5Cdelta%20%5Cin%20(0,%201%5D"> is the confidence parameter, and <img src="https://latex.codecogs.com/png.latex?r(m,%20%5Cdelta)">, a function of the <em>sample size</em> <img src="https://latex.codecogs.com/png.latex?m"> and the confidence parameter <img src="https://latex.codecogs.com/png.latex?%5Cdelta">, is the <em>regularisation</em> term that satisfies: <img src="https://latex.codecogs.com/png.latex?%0A%5Clim_%7Bm%20%5Cto%20+%5Cinfty%7D%20r(m,%20%5Cdelta)%20=%200.%0A"> The PAC-Bayes upper generalisation bound is one kind of PAC bound. It was first proposed by <span class="citation" data-cites="mcallester1999pac">McAllester (1999)</span> and has attracted much research interest. Many subsequent improvements have been made to further tighten this classic PAC-Bayes bound or to extend it to more general loss functions. However, the classic PAC-Bayes theorem is still the backbone. In this post, I will show how to prove this interesting theorem.</p>
<section id="auxillary-lemmas" class="level2" data-number="1">
<h2 data-number="1" class="anchored" data-anchor-id="auxillary-lemmas"><span class="header-section-number">1</span> Auxiliary lemmas</h2>
<p>To prove the classic PAC-Bayes theorem, we need the two auxiliary lemmas shown below.</p>
<section id="change-of-measure-inequality-for-kullback-leibler-divergence" class="level3" data-number="1.1">
<h3 data-number="1.1" class="anchored" data-anchor-id="change-of-measure-inequality-for-kullback-leibler-divergence"><span class="header-section-number">1.1</span> Change of measure inequality for Kullback-Leibler divergence</h3>
<div id="lem-change-of-measure" class="theorem lemma">
<p><span class="theorem-title"><strong>Lemma 1</strong></span> <span class="citation" data-cites="banerjee2006bayesian">(Banerjee 2006 - Lemma 1)</span> For any measurable function <img src="https://latex.codecogs.com/png.latex?%5Cphi(h)"> on the set of predictors under consideration <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BH%7D">, and any distributions <img src="https://latex.codecogs.com/png.latex?P"> and <img src="https://latex.codecogs.com/png.latex?Q"> on <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BH%7D">, the following inequality holds: <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathbb%7BE%7D_%7BQ%7D%20%5B%5Cphi(h)%5D%20%5Cle%20%5Cmathrm%7BKL%7D%20%5BQ%20%5CVert%20P%5D%20+%20%5Cln%20%5Cmathbb%7BE%7D_%7BP%7D%20%5B%5Cexp(%5Cphi(h))%5D.%0A"> Further, <img src="https://latex.codecogs.com/png.latex?%0A%5Csup_%7B%5Cphi%7D%20%5Cmathbb%7BE%7D_%7BQ%7D%20%5B%5Cphi(h)%5D%20-%20%5Cln%20%5Cmathbb%7BE%7D_%7BP%7D%20%5B%5Cexp(%5Cphi(h))%5D%20=%20%5Cmathrm%7BKL%7D%20%5BQ%20%5CVert%20P%5D.%0A"></p>
</div>
<div class="proof">
<p><span class="proof-title"><em>Proof</em>. </span>For any measurable function <img src="https://latex.codecogs.com/png.latex?%5Cphi(h)">, the following holds: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%20%20%20%20%5Cmathbb%7BE%7D_%7BQ%7D%20%5B%5Cphi(h)%5D%20&amp;%20=%20%5Cmathbb%7BE%7D_%7BQ%7D%20%5Cleft%5B%20%5Cln%20%5Cleft(%20%5Cexp(%5Cphi(h))%20%5Cfrac%7BQ(h)%7D%7BP(h)%7D%20%5Cfrac%7BP(h)%7D%7BQ(h)%7D%20%5Cright)%20%5Cright%5D%20%5C%5C%0A%20%20%20%20&amp;%20=%20%5Cmathrm%7BKL%7D%20%5BQ%20%5CVert%20P%5D%20+%20%5Cmathbb%7BE%7D_%7BQ%7D%20%5Cleft%5B%20%5Cln%20%5Cleft(%20%5Cexp(%5Cphi(h))%20%5Cfrac%7BP(h)%7D%7BQ(h)%7D%20%5Cright)%20%5Cright%5D%20%5C%5C%0A%20%20%20%20&amp;%20%5Cle%20%5Cmathrm%7BKL%7D%20%5BQ%20%5CVert%20P%5D%20+%20%5Cln%20%5Cmathbb%7BE%7D_%7BQ%7D%20%5Cleft%5B%20%5Cexp(%5Cphi(h))%20%5Cfrac%7BP(h)%7D%7BQ(h)%7D%20%5Cright%5D%20%5C%5C%0A%20%20%20%20&amp;%20%5Cqquad%20%5Ctext%7B(Jensen's%20inequality)%7D%5C%5C%0A%20%20%20%20&amp;%20=%20%5Cmathrm%7BKL%7D%20%5BQ%20%5CVert%20P%5D%20+%20%5Cln%20%5Cmathbb%7BE%7D_%7BP%7D%20%5Cleft%5B%20%5Cexp(%5Cphi(h))%20%5Cright%5D.%0A%5Cend%7Baligned%7D%0A"></p>
<p>For the second part of the lemma, we need to examine the equality condition of Jensen’s inequality. Since <img src="https://latex.codecogs.com/png.latex?%5Cln(x)"> is a strictly concave function for <img src="https://latex.codecogs.com/png.latex?x%20%3E%200">, equality holds when the argument of the logarithm is constant. Choosing this constant to be 1 gives: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%20%20%20%20%5Cexp%20%5Cleft(%20%5Cphi(h)%20%5Cright)%20&amp;%20%5Cfrac%7BP(h)%7D%7BQ(h)%7D%20=%201%20%5C%5C%0A%20%20%20%20%5Ciff%20%5Cphi(h)%20&amp;%20=%20%5Cln%20%5Cleft%5B%20%5Cfrac%7BQ(h)%7D%7BP(h)%7D%20%5Cright%5D.%0A%5Cend%7Baligned%7D%0A"> With this choice of <img src="https://latex.codecogs.com/png.latex?%5Cphi(h)">, we can verify that the equality does hold.</p>
<p>This completes the proof.</p>
</div>
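<p>Lemma 1 is easy to check numerically when the set of predictors is finite. The following minimal sketch in plain Python assumes two hypothetical distributions over four predictors (the numbers and function names are made up). It draws random functions φ to test the inequality, then plugs in φ(h) = ln(Q(h) / P(h)) to test the equality case:</p>

```python
import math
import random


def kl(q, p):
    """KL[Q || P] for two discrete distributions given as lists of probabilities."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)


def gap(phi, q, p):
    """KL[Q || P] + ln E_P[exp(phi)] - E_Q[phi]; Lemma 1 says this is non-negative."""
    e_q = sum(qi * fi for qi, fi in zip(q, phi))
    log_e_p = math.log(sum(pi * math.exp(fi) for pi, fi in zip(p, phi)))
    return kl(q, p) + log_e_p - e_q


random.seed(0)
p = [0.1, 0.2, 0.3, 0.4]   # hypothetical P over four predictors
q = [0.4, 0.3, 0.2, 0.1]   # hypothetical Q over the same predictors

# The inequality should hold for arbitrary phi ...
gaps = [gap([random.uniform(-3, 3) for _ in p], q, p) for _ in range(1000)]
# ... and be tight for phi(h) = ln(Q(h) / P(h)).
tight_gap = gap([math.log(qi / pi) for qi, pi in zip(q, p)], q, p)

print("smallest gap over random phi:", min(gaps))
print("gap at phi = ln(Q / P):", tight_gap)
```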
</section>
<section id="concentration-inequality" class="level3" data-number="1.2">
<h3 data-number="1.2" class="anchored" data-anchor-id="concentration-inequality"><span class="header-section-number">1.2</span> Concentration inequality</h3>
<div id="lem-concentration-inequality" class="theorem lemma">
<p><span class="theorem-title"><strong>Lemma 2</strong></span> <span class="citation" data-cites="shalev2014understanding">(Shalev-Shwartz and Ben-David 2014 - Exercise 31.1)</span> Let <img src="https://latex.codecogs.com/png.latex?X"> be a random variable that satisfies: <img src="https://latex.codecogs.com/png.latex?%5Cmathrm%7BPr%7D%20(X%20%5Cge%20%5Cepsilon)%20%5Cle%20e%5E%7B-2m%20%5Cepsilon%5E%7B2%7D%7D">. Prove that <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathbb%7BE%7D%20%5Cleft%5B%20e%5E%7B2(m%20-%201)%20X%5E%7B2%7D%7D%20%5Cright%5D%20%5Cle%20m.%0A"></p>
</div>
<div class="proof">
<p><span class="proof-title"><em>Proof</em>. </span>Since the assumption is expressed in terms of a probability, while the conclusion is written in the form of an expectation, the first step is to express the expectation in terms of probabilities.</p>
<p>For simplicity, let <img src="https://latex.codecogs.com/png.latex?Y%20=%20e%5E%7B2(m%20-%201)%20X%5E%7B2%7D%7D">. Since <img src="https://latex.codecogs.com/png.latex?X%20%5Cin%20%5B0,%20+%5Cinfty)">, we have <img src="https://latex.codecogs.com/png.latex?Y%20%5Cin%20%5B1,%20+%5Cinfty)">, and <img src="https://latex.codecogs.com/png.latex?Y"> can be written as: <img src="https://latex.codecogs.com/png.latex?%0AY%20=%20%5Cint_%7B1%7D%5E%7B+%5Cinfty%7D%20%5Cpmb%7B1%7D(Y%20%5Cge%20t)%20%5C,%20%5Cmathrm%7Bd%7Dt%20+%201,%0A"> where <img src="https://latex.codecogs.com/png.latex?%5Cpmb%7B1%7D(A)"> is the indicator function of event <img src="https://latex.codecogs.com/png.latex?A">. Note that the integral above is the area of a rectangle with height 1 and width <img src="https://latex.codecogs.com/png.latex?Y%20-%201">.</p>
<p>One important property of the indicator function is that: <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathbb%7BE%7D%20%5Cleft%5B%20%5Cpmb%7B1%7D(Y%20%5Cge%20t)%20%5Cright%5D%20=%20%5Cmathrm%7BPr%7D(Y%20%5Cge%20t).%0A"> This allows us to express the expectation of interest as: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%5Cmathbb%7BE%7D%5BY%5D%20&amp;%20=%20%5Cmathbb%7BE%7D%20%5Cleft%5B%20%5Cint_%7B1%7D%5E%7B+%5Cinfty%7D%20%5Cpmb%7B1%7D(Y%20%5Cge%20t)%20%5C,%20%5Cmathrm%7Bd%7Dt%20%5Cright%5D%20+%201%20%5C%5C%0A&amp;%20=%20%5Cint_%7B1%7D%5E%7B+%5Cinfty%7D%20%5Cmathbb%7BE%7D%20%5B%5Cpmb%7B1%7D(Y%20%5Cge%20t)%5D%20%5C,%20%5Cmathrm%7Bd%7Dt%20+%201%20%5Cquad%20%5Ctext%7B(Fubini's%20theorem)%7D%20%5C%5C%0A&amp;%20=%20%5Cint_%7B1%7D%5E%7B+%5Cinfty%7D%20%5Cmathrm%7BPr%7D(Y%20%5Cge%20t)%20%5C,%20%5Cmathrm%7Bd%7Dt%20+%201.%0A%5Cend%7Baligned%7D%0A"> Or: <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathbb%7BE%7D%20%5Cleft%5B%20e%5E%7B2(m%20-%201)%20X%5E%7B2%7D%7D%20%5Cright%5D%20=%20%5Cint_%7B1%7D%5E%7B+%5Cinfty%7D%20%5Cmathrm%7BPr%7D(%20e%5E%7B2(m%20-%201)%20X%5E%7B2%7D%7D%20%5Cge%20x)%20%5C,%20%5Cmathrm%7Bd%7Dx%20+%201.%0A"></p>
<p>We then make a change of variable from <img src="https://latex.codecogs.com/png.latex?x"> to <img src="https://latex.codecogs.com/png.latex?%5Cepsilon"> to utilise the given inequality in the assumption. Let’s define: <img src="https://latex.codecogs.com/png.latex?%0Ax%20=%20e%5E%7B2(m%20-%201)%20%5Cepsilon%5E%7B2%7D%7D.%0A"> Since <img src="https://latex.codecogs.com/png.latex?%5Cepsilon"> is assumed to be non-negative, we can express it as: <img src="https://latex.codecogs.com/png.latex?%0A%5Cepsilon%20=%20%5Csqrt%7B%5Cfrac%7B%5Cln%20x%7D%7B2(m%20-%201)%7D%7D,%0A"> and: <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathrm%7Bd%7Dx%20=%204(m%20-%201)%20%5Cepsilon%20%5C,%20e%5E%7B2(m%20-%201)%20%5Cepsilon%5E%7B2%7D%7D%20%5C,%20%5Cmathrm%7Bd%7D%20%5Cepsilon.%0A"></p>
<p>The expectation of interest can, therefore, be written as: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%20%20%20%20%5Cmathbb%7BE%7D%20%5Cleft%5B%20e%5E%7B2(m%20-%201)%20X%5E%7B2%7D%7D%20%5Cright%5D%20&amp;%20=%20%5Cint_%7B0%7D%5E%7B+%5Cinfty%7D%20%5Cmathrm%7BPr%7D%20%5Cleft(%20e%5E%7B2(m%20-%201)%20X%5E%7B2%7D%7D%20%5Cge%20e%5E%7B2(m%20-%201)%20%5Cepsilon%5E%7B2%7D%7D%20%5Cright)%204(m%20-%201)%20%5Cepsilon%20%5C,%20e%5E%7B2(m%20-%201)%20%5Cepsilon%5E%7B2%7D%7D%20%5C,%20%5Cmathrm%7Bd%7D%20%5Cepsilon%20%20+%201%5C%5C%0A%20%20%20%20&amp;%20=%20%5Cint_%7B0%7D%5E%7B+%5Cinfty%7D%20%5Cmathrm%7BPr%7D%20%5Cunderbrace%7B%5Cleft(%20X%20%5Cge%20%5Cepsilon%20%5Cright)%7D_%7B%5Cle%20e%5E%7B-2m%5Cepsilon%5E%7B2%7D%7D%7D%204(m%20-%201)%20%5Cepsilon%20%5C,%20e%5E%7B2(m%20-%201)%20%5Cepsilon%5E%7B2%7D%7D%20%5C,%20%5Cmathrm%7Bd%7D%20%5Cepsilon%20+%201%5C%5C%0A%20%20%20%20&amp;%20%5Cle%204(m%20-%201)%20%5Cint_%7B0%7D%5E%7B+%5Cinfty%7D%20%5Cepsilon%20%5C,%20e%5E%7B-2%20%5Cepsilon%5E%7B2%7D%7D%20%5C,%20%5Cmathrm%7Bd%7D%20%5Cepsilon%20+%201%20=%20m.%0A%5Cend%7Baligned%7D%0A"></p>
</div>
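<p>The last step of the proof relies on the integral of ε exp(−2ε²) over [0, +∞) being equal to 1/4, which follows from the antiderivative −exp(−2ε²)/4. As a rough numerical check, the sketch below (plain Python; the truncation point, the number of steps and the function name are arbitrary choices of mine) estimates this integral with the midpoint rule and evaluates the resulting bound 4(m − 1) · integral + 1 for a few sample sizes:</p>

```python
import math


def tail_integral(upper=10.0, steps=200_000):
    """Midpoint-rule estimate of the integral of eps * exp(-2 eps^2) over [0, upper]."""
    h = upper / steps
    return sum((i + 0.5) * h * math.exp(-2.0 * ((i + 0.5) * h) ** 2) * h
               for i in range(steps))


integral = tail_integral()  # closed form: 1/4
bounds = {m: 4 * (m - 1) * integral + 1 for m in (2, 10, 100)}

print("integral estimate:", integral)
for m, bound in bounds.items():
    print(f"m = {m}: 4(m - 1) * integral + 1 = {bound}")
```

<p>Truncating the integral at 10 is harmless here because the integrand decays like exp(−2ε²).</p>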
</section>
</section>
<section id="pac-bayes-bound" class="level2" data-number="2">
<h2 data-number="2" class="anchored" data-anchor-id="pac-bayes-bound"><span class="header-section-number">2</span> PAC-Bayes bound</h2>
<div id="thm-pac-bayes-bound" class="theorem">
<p><span class="theorem-title"><strong>Theorem 1</strong></span> Let <img src="https://latex.codecogs.com/png.latex?D"> be an arbitrary distribution over an example domain <img src="https://latex.codecogs.com/png.latex?Z">. Let <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BH%7D"> be a hypothesis class, <img src="https://latex.codecogs.com/png.latex?%5Cell:%20%5Cmathcal%7BH%7D%20%5Ctimes%20Z%20%5Cto%20%5B0,%201%5D"> be a loss function, <img src="https://latex.codecogs.com/png.latex?%5Cpi"> be a prior distribution over <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BH%7D">, and <img src="https://latex.codecogs.com/png.latex?%5Cdelta%20%5Cin%20(0,%201%5D">. If <img src="https://latex.codecogs.com/png.latex?S%20=%20%5C%7Bz_j%5C%7D_%7Bj=1%7D%5E%7Bm%7D"> is an i.i.d. training set sampled according to <img src="https://latex.codecogs.com/png.latex?D">, then for any “posterior” <img src="https://latex.codecogs.com/png.latex?Q"> over <img src="https://latex.codecogs.com/png.latex?%5Cmathcal%7BH%7D">, the following holds: <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathrm%7BPr%7D%20%5Cleft(%20%5Cmathbb%7BE%7D_%7Bz_%7Bj%7D%20%5Csim%20D%7D%20%5Cmathbb%7BE%7D_%7Bh%20%5Csim%20Q%7D%20%5Cleft%5B%20%5Cell(h,%20z_%7Bj%7D)%20%5Cright%5D%20%5Cle%20%5Cmathbb%7BE%7D_%7Bz_%7Bj%7D%20%5Csim%20S%7D%20%5Cmathbb%7BE%7D_%7Bh%20%5Csim%20Q%7D%20%5Cleft%5B%20%5Cell(h,%20z_%7Bj%7D)%20%5Cright%5D%20+%20%5Csqrt%7B%5Cfrac%7B%5Cmathrm%7BKL%7D%20%5BQ%20%5CVert%20%5Cpi%5D%20+%20%5Cfrac%7B%5Cln%20m%7D%7B%5Cdelta%7D%7D%7B2(m%20-%201)%7D%7D%20%5Cright)%20%5Cge%201%20-%20%5Cdelta.%0A"></p>
</div>
<div class="proof">
<p><span class="proof-title"><em>Proof</em>. </span>We define some notation to simplify the proof:</p>
<ul>
<li><img src="https://latex.codecogs.com/png.latex?L%20=%20%5Cmathbb%7BE%7D_%7Bz_%7Bj%7D%20%5Csim%20D%7D%20%5Cleft%5B%20%5Cell(h,%20z_%7Bj%7D)%20%5Cright%5D"></li>
<li><img src="https://latex.codecogs.com/png.latex?%5Chat%7BL%7D%20=%20%5Cmathbb%7BE%7D_%7Bz_%7Bj%7D%20%5Csim%20S%7D%20%5Cleft%5B%20%5Cell(h,%20z_%7Bj%7D)%20%5Cright%5D%20=%20%5Cfrac%7B1%7D%7Bm%7D%20%5Csum_%7Bj=1%7D%5E%7Bm%7D%20%5Cell(h,%20z_%7Bj%7D)"></li>
<li><img src="https://latex.codecogs.com/png.latex?%5CDelta%20L%20=%20L%20-%20%5Chat%7BL%7D"></li>
</ul>
<p>Applying Lemma&nbsp;1 with <img src="https://latex.codecogs.com/png.latex?P(h)%20=%20%5Cpi%20(h)"> and <img src="https://latex.codecogs.com/png.latex?%5Cphi(h)%20=%202(m%20-%201)%20(%5CDelta%20L)%5E%7B2%7D"> gives: <span id="eq-lower-bound-log_expect"><img src="https://latex.codecogs.com/png.latex?%0A2(m%20-%201)%20%5Cmathbb%7BE%7D_%7BQ%7D%20%5Cleft%5B%20(%5CDelta%20L)%5E%7B2%7D%20%5Cright%5D%20-%20%5Cmathrm%7BKL%7D%20%5BQ%20%5CVert%20%5Cpi%5D%20%5Cle%20%5Ctextcolor%7Bpurple%7D%7B%5Cln%20%5Cmathbb%7BE%7D_%7B%5Cpi%7D%20%5Cleft%5B%5Cexp%20%5Cleft(%202(m%20-%201)%20(%5CDelta%20L)%5E%7B2%7D%20%5Cright)%20%5Cright%5D%7D.%0A%5Ctag%7B1%7D"></span></p>
<p>We upper-bound the term on the RHS (highlighted in <span style="color: purple;">purple</span>) using Lemma&nbsp;2. To do that, we consider the loss on each observed data point <img src="https://latex.codecogs.com/png.latex?%5Cell(h,%20z_%7Bj%7D)"> as a random variable in <img src="https://latex.codecogs.com/png.latex?%5B0,%201%5D"> with true and empirical means <img src="https://latex.codecogs.com/png.latex?L"> and <img src="https://latex.codecogs.com/png.latex?%5Chat%7BL%7D">, respectively. Hoeffding’s inequality then gives: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%5Cmathrm%7BPr%7D%20%5Cleft(%20%5CDelta%20L%20%5Cge%20%5Cepsilon%20%5Cright)%20&amp;%20=%20%5Cmathrm%7BPr%7D%20%5Cleft(%20L%20-%20%5Chat%7BL%7D%20%5Cge%20%5Cepsilon%20%5Cright)%5C%5C%0A&amp;%20%5Cle%20%5Cmathrm%7BPr%7D%20%5Cleft(%20%7C%20L%20-%20%5Chat%7BL%7D%20%7C%20%5Cge%20%5Cepsilon%20%5Cright)%5C%5C%0A&amp;%20%5Cle%20%5Cexp(-2m%20%5Cepsilon%5E%7B2%7D),%20%5Cquad%20%5Cepsilon%20%5Cge%200.%0A%5Cend%7Baligned%7D%0A"> According to Lemma&nbsp;2, this implies: <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathbb%7BE%7D_%7BS%7D%20%5Cleft%5B%5Cexp%20%5Cleft(%202(m%20-%201)%20(%5CDelta%20L)%5E%7B2%7D%20%5Cright)%20%5Cright%5D%20%5Cle%20m.%0A"> Taking the expectation w.r.t. 
<img src="https://latex.codecogs.com/png.latex?h%20%5Csim%20%5Cpi(h)"> on both sides and applying Fubini’s theorem (to interchange the 2 expectations) gives: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A&amp;%20%5Cmathbb%7BE%7D_%7BS%7D%20%5Cmathbb%7BE%7D_%7B%5Cpi%7D%20%5Cleft%5B%5Cexp%20%5Cleft(%202(m%20-%201)%20(%5CDelta%20L)%5E%7B2%7D%20%5Cright)%20%5Cright%5D%20%5Cle%20%5Cmathbb%7BE%7D_%7B%5Cpi%7D%20%5Cleft%5B%20m%20%5Cright%5D%20=%20m%5C%5C%0A&amp;%20%5Cimplies%20%5Cln%20%5Cmathbb%7BE%7D_%7BS%7D%20%5Cmathbb%7BE%7D_%7B%5Cpi%7D%20%5Cleft%5B%5Cexp%20%5Cleft(%202(m%20-%201)%20(%5CDelta%20L)%5E%7B2%7D%20%5Cright)%20%5Cright%5D%20%5Cle%20%5Cln%20m%5C%5C%0A&amp;%20%5Cimplies%20%5Cmathbb%7BE%7D_%7BS%7D%20%5Ctextcolor%7Bpurple%7D%7B%5Cln%20%5Cmathbb%7BE%7D_%7B%5Cpi%7D%20%5Cleft%5B%5Cexp%20%5Cleft(%202(m%20-%201)%20(%5CDelta%20L)%5E%7B2%7D%20%5Cright)%20%5Cright%5D%7D%20%5Cle%20%5Cln%20m.%0A%5Cend%7Baligned%7D%0A"> Note that the last implication is due to Jensen’s inequality.</p>
<p>We then apply Markov’s inequality for the term highlighted in <span style="color: purple;">purple</span>: <img src="https://latex.codecogs.com/png.latex?%0A%5Cbegin%7Baligned%7D%0A%5Cmathrm%7BPr%7D%20%5Cleft(%20%5Ctextcolor%7Bpurple%7D%7B%5Cln%20%5Cmathbb%7BE%7D_%7B%5Cpi%7D%20%5Cleft%5B%5Cexp%20%5Cleft(%202(m%20-%201)%20(%5CDelta%20L)%5E%7B2%7D%20%5Cright)%20%5Cright%5D%7D%20%5Cge%20%5Cvarepsilon%20%5Cright)%20&amp;%20%5Cle%20%5Cfrac%7B%5Cmathbb%7BE%7D_%7BS%7D%20%5Ctextcolor%7Bpurple%7D%7B%5Cln%20%5Cmathbb%7BE%7D_%7B%5Cpi%7D%20%5Cleft%5B%5Cexp%20%5Cleft(%202(m%20-%201)%20(%5CDelta%20L)%5E%7B2%7D%20%5Cright)%20%5Cright%5D%7D%7D%7B%5Cvarepsilon%7D%20%5C%5C%0A&amp;%20%5Cle%20%5Cfrac%7B%5Cln%20m%7D%7B%5Cvarepsilon%7D.%0A%5Cend%7Baligned%7D%0A"></p>
<p>This implies: <span id="eq-bound_log_expect_prob"><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathrm%7BPr%7D%20%5Cleft(%20%5Ctextcolor%7Bpurple%7D%7B%5Cln%20%5Cmathbb%7BE%7D_%7B%5Cpi%7D%20%5Cleft%5B%5Cexp%20%5Cleft(%202(m%20-%201)%20(%5CDelta%20L)%5E%7B2%7D%20%5Cright)%20%5Cright%5D%7D%20%5Cle%20%5Cvarepsilon%20%5Cright)%20%5Cge%201%20-%20%5Cfrac%7B%5Cln%20m%7D%7B%5Cvarepsilon%7D.%0A%5Ctag%7B2%7D"></span></p>
<p>Combining the results in Equation&nbsp;1 and Equation&nbsp;2 gives: <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathrm%7BPr%7D%20%5Cleft(%202(m%20-%201)%20%5Cmathbb%7BE%7D_%7BQ%7D%20%5Cleft%5B%20(%5CDelta%20L)%5E%7B2%7D%20%5Cright%5D%20-%20%5Cmathrm%7BKL%7D%20%5BQ%20%5CVert%20%5Cpi%5D%20%5Cle%20%5Cvarepsilon%20%5Cright)%20%5Cge%201%20-%20%5Cfrac%7B%5Cln%20m%7D%7B%5Cvarepsilon%7D.%0A"></p>
<p>This is equivalent to: <span id="eq-almost-done"><img src="https://latex.codecogs.com/png.latex?%0A%5Cmathrm%7BPr%7D%20%5Cleft(%20%5Cmathbb%7BE%7D_%7BQ%7D%20%5Cleft%5B%20(%5CDelta%20L)%5E%7B2%7D%20%5Cright%5D%20%5Cle%20%5Cfrac%7B%5Cmathrm%7BKL%7D%20%5BQ%20%5CVert%20%5Cpi%5D%20+%20%5Cvarepsilon%7D%7B2(m%20-%201)%7D%20%5Cright)%20%5Cge%201%20-%20%5Cfrac%7B%5Cln%20m%7D%7B%5Cvarepsilon%7D.%0A%5Ctag%7B3%7D"></span></p>
<p>Note that the square function is convex, so Jensen’s inequality gives: <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathbb%7BE%7D_%7BQ%7D%20%5Cleft%5B%20(%5CDelta%20L)%5E%7B2%7D%20%5Cright%5D%20%5Cge%20%5Cleft(%20%5Cmathbb%7BE%7D_%7BQ%7D%20%5Cleft%5B%20%5CDelta%20L%20%5Cright%5D%20%5Cright)%5E%7B2%7D.%0A"></p>
<p>Combining this with Equation&nbsp;3 and taking the square root gives: <img src="https://latex.codecogs.com/png.latex?%0A%5Cmathrm%7BPr%7D%20%5Cleft(%20%5Cmathbb%7BE%7D_%7BQ%7D%20%5Cleft%5B%20%5CDelta%20L%20%5Cright%5D%20%5Cle%20%5Csqrt%7B%5Cfrac%7B%5Cmathrm%7BKL%7D%20%5BQ%20%5CVert%20%5Cpi%5D%20+%20%5Cvarepsilon%7D%7B2(m%20-%201)%7D%7D%20%5Cright)%20%5Cge%201%20-%20%5Cfrac%7B%5Cln%20m%7D%7B%5Cvarepsilon%7D.%0A"></p>
<p>Setting <img src="https://latex.codecogs.com/png.latex?%5Cdelta%20=%20%5Cfrac%7B%5Cln%20m%7D%7B%5Cvarepsilon%7D"> and expanding <img src="https://latex.codecogs.com/png.latex?%5CDelta%20L"> according to its definition completes the proof.</p>
</div>
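<p>To get a feel for the size of the square-root term in Theorem&nbsp;1, here is a minimal sketch in plain Python. It assumes a hypothetical finite hypothesis class with four predictors, a uniform prior and a made-up “posterior”. The function names and all numbers are illustrative only.</p>

```python
import math


def kl(q, p):
    """KL[Q || P] for two discrete distributions given as lists of probabilities."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)


def pac_bayes_gap(q, prior, m, delta):
    """Square-root term of Theorem 1: sqrt((KL[Q || pi] + ln(m) / delta) / (2 (m - 1)))."""
    return math.sqrt((kl(q, prior) + math.log(m) / delta) / (2 * (m - 1)))


prior = [0.25, 0.25, 0.25, 0.25]   # uniform prior over four hypothetical predictors
q = [0.7, 0.1, 0.1, 0.1]           # a "posterior" concentrated on the first predictor

gaps = {m: pac_bayes_gap(q, prior, m, delta=0.05) for m in (100, 10_000, 1_000_000)}
for m, g in gaps.items():
    print(f"m = {m}: true risk is at most empirical risk + {g:.4f} (w.p. at least 0.95)")
```

<p>The term shrinks as the sample size grows and as the “posterior” moves closer to the prior, since the KL divergence is then smaller.</p>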
</section>
<section id="discussion" class="level2" data-number="3">
<h2 data-number="3" class="anchored" data-anchor-id="discussion"><span class="header-section-number">3</span> Discussion</h2>
<p>As far as I know, the result in Theorem&nbsp;1 is a seminal PAC-Bayes bound in the PAC learning literature. Readers can refer to the tighter PAC-Bayes bounds that were derived later.</p>
</section>
<section id="references" class="level2" data-number="4">
<h2 data-number="4" class="anchored" data-anchor-id="references"><span class="header-section-number">4</span> References</h2>
<div id="refs" class="references csl-bib-body hanging-indent">
<div id="ref-banerjee2006bayesian" class="csl-entry">
Banerjee, Arindam. 2006. <span>“On Bayesian Bounds.”</span> <em>International Conference on Machine Learning</em>, 81–88.
</div>
<div id="ref-mcallester1999pac" class="csl-entry">
McAllester, David A. 1999. <span>“PAC-Bayesian Model Averaging.”</span> <em>Conference on Computational Learning Theory</em>, 164–70.
</div>
<div id="ref-shalev2014understanding" class="csl-entry">
Shalev-Shwartz, Shai, and Shai Ben-David. 2014. <em>Understanding Machine Learning: From Theory to Algorithms</em>. Cambridge university press.
</div>
</div>


<!-- -->

</section>

<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-reuse"><h2 class="anchored quarto-appendix-heading">Reuse</h2><div class="quarto-appendix-contents"><div><a rel="license" href="https://creativecommons.org/licenses/by/4.0/">CC BY 4.0</a></div></div></section><section class="quarto-appendix-contents" id="quarto-citation"><h2 class="anchored quarto-appendix-heading">Citation</h2><div><div class="quarto-appendix-secondary-label">BibTeX citation:</div><pre class="sourceCode code-with-copy quarto-appendix-bibtex"><code class="sourceCode bibtex">@online{nguyen2020,
  author = {Nguyen, Cuong},
  title = {PAC-Bayes Bounds for Generalisation Error},
  date = {2020-12-26},
  url = {https://cnguyen10.github.io/posts/PAC-Bayes-bounds/},
  langid = {en}
}
</code></pre><div class="quarto-appendix-secondary-label">For attribution, please cite this work as:</div><div id="ref-nguyen2020" class="csl-entry quarto-appendix-citeas">
Nguyen, Cuong. 2020. <span>“PAC-Bayes Bounds for Generalisation
Error.”</span> December 26. <a href="https://cnguyen10.github.io/posts/PAC-Bayes-bounds/">https://cnguyen10.github.io/posts/PAC-Bayes-bounds/</a>.
</div></div></section></div> ]]></description>
  <category>Statistical Learning Theory</category>
  <category>Generalisation</category>
  <guid>https://cnguyen10.github.io/posts/PAC-Bayes-bounds/</guid>
  <pubDate>Sat, 26 Dec 2020 00:00:00 GMT</pubDate>
</item>
<item>
  <title>VAE: normalising constant matters</title>
  <dc:creator>Cuong Nguyen</dc:creator>
  <link>https://cnguyen10.github.io/posts/vae-normalising-constant-matters/</link>
  <description><![CDATA[ 




<p>Variational auto-encoder (VAE) is one of the most popular generative models in machine learning nowadays. However, the rapid development of the field has made many machine learning practitioners (or, maybe only me) focus too much on deep learning without paying much attention to some fundamentals, such as linear regression. That causes much confusion due to the discrepancy between the derivation and the practical implementation, in which the regularisation term of the loss, specifically the Kullback-Leibler (KL) divergence, is weighted by some factor <img src="https://latex.codecogs.com/png.latex?%5Cbeta">. I experienced this confusion myself and struggled with it at the beginning of my research. Even though weighting the KL divergence term by a factor <img src="https://latex.codecogs.com/png.latex?%5Cbeta"> could temporarily resolve the issue, I kept questioning why the balancing between the reconstruction loss and the KL divergence is necessary. Eventually, the answer turned out to be quite simple: the normalising constant in the reconstruction loss (or negative log-likelihood) is often ignored. This omission is the main cause of the imbalance between the two losses.</p>
<section id="variational-auto-encoder" class="level2" data-number="1">
<h2 data-number="1" class="anchored" data-anchor-id="variational-auto-encoder"><span class="header-section-number">1</span> Variational auto-encoder</h2>
<p>Given data points <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D%20=%20%5C%7Bx_%7Bi%7D%5C%7D_%7Bn=1%7D%5E%7BN%7D">, the model of a VAE assumes that there is a corresponding latent variable <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bz%7D%20=%20%5C%7B%20z_%7Bn%7D%20%5C%7D_%7Bn=1%7D%5E%7BN%7D"> that generates data <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D">. In short, the objective function of a VAE is to minimise the variational-free energy (VFE) given as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmin_%7Bq%7D%20%5Cunderbrace%7B%5Cmathbb%7BE%7D_%7Bq(%5Cmathbf%7Bz%7D)%7D%20%5Cleft%5B%20-%20%5Cln%20p(%5Cmathbf%7Bx%7D%20%7C%20%5Cmathbf%7Bz%7D)%20%5Cright%5D%7D_%7B%5Ctext%7Breconstruction%20loss%7D%7D%20+%20%5Ctextcolor%7Bred%7D%7B%5Cbeta%7D%20%5Cmathrm%7BKL%7D%20%5Cleft%5B%20q(%5Cmathbf%7Bz%7D)%20%5CVert%20p(%5Cmathbf%7Bx%7D)%20%5Cright%5D,%20%5Ctag%7Bvfe%7D%0A"> where <img src="https://latex.codecogs.com/png.latex?q(%5Cmathbf%7Bz%7D)"> is the variational distribution of the latent variable, and <img src="https://latex.codecogs.com/png.latex?%5Ctextcolor%7Bred%7D%7B%5Cbeta%7D%20=%201"> is the weighting factor.</p>
<p>In practice, people often “specify” the reconstruction loss as mean squared error (MSE) or binary cross-entropy loss and use gradient descent to minimise the VFE. With <img src="https://latex.codecogs.com/png.latex?%5Cbeta%20=%201"> as in (vfe), the reconstructions of different images all look like the same image (see Figure 1 (top)), whereas setting <img src="https://latex.codecogs.com/png.latex?%5Cbeta%20%5Cll%201"> results in much better reconstructed images (see Figure 1 (bottom)).</p>
<figure class="figure">
<img src="https://i.stack.imgur.com/QKrOM.jpg" alt="same reconstructed images" style="width:100%" class="figure-img"> <img src="https://i.stack.imgur.com/63xvp.jpg" alt="decent reconstructed images" style="width:100%" class="figure-img">
<figcaption>
Figure 1. The reconstructed images from VAE with β = 1 (top) and β ≪ 1 (bottom). Source: <a href="https://stats.stackexchange.com/questions/341954/balancing-reconstruction-vs-kl-loss-variational-autoencoder">stats.stackexchange.com</a>
</figcaption>
</figure>
<p>Some justifications for setting <img src="https://latex.codecogs.com/png.latex?%5Cbeta"> to a small value have been proposed, but none of them satisfies me. For example:</p>
<ul>
<li>Setting <img src="https://latex.codecogs.com/png.latex?%5Cbeta%20%5Cll%201"> is sometimes argued to lead to an even “further lower-bound”, so that maximising this “further lower-bound” is still mathematically reasonable. However, this bound is very loose. Can we do something better?</li>
<li>One can cast the problem as a constrained optimisation, as in the <a href="https://openreview.net/forum?id=Sy2fzU9gl">β-VAE paper</a>. However, β in that case is the Lagrange multiplier and should be obtained through the optimisation. Is it mathematically correct to treat β as a hyper-parameter? I doubt that.</li>
</ul>
<p>Later on, I figured out that the main reason for the imbalance between the two losses is the “specification” of the reconstruction loss. Simply specifying the type of the loss <img src="https://latex.codecogs.com/png.latex?-%5Cln%20p(%5Cmathbf%7Bx%7D%20%5Cvert%20%5Cmathbf%7Bz%7D)"> as MSE or binary cross-entropy ignores the normalising constant, resulting in an incorrect reconstruction loss. The correct way is to specify the modelling assumption of the likelihood <img src="https://latex.codecogs.com/png.latex?p(%5Cmathbf%7Bx%7D%20%5Cvert%20%5Cmathbf%7Bz%7D)">, which, in the case of a VAE, goes back to linear regression.</p>
<p>In the following sections, <img src="https://latex.codecogs.com/png.latex?f(%5Cmathbf%7Bz%7D;%20%5Ctheta)"> denotes the output of the decoder, parameterised by a neural network with weights <img src="https://latex.codecogs.com/png.latex?%5Ctheta">. Usually, <img src="https://latex.codecogs.com/png.latex?f(%5Cmathbf%7Bz%7D;%20%5Ctheta)"> is assumed to be the reconstructed image, but this might not always be true, depending on the assumption used.</p>
</section>
<section id="reconstruction-likelihood-with-gaussian-assumption" class="level2" data-number="2">
<h2 data-number="2" class="anchored" data-anchor-id="reconstruction-likelihood-with-gaussian-assumption"><span class="header-section-number">2</span> Reconstruction likelihood with Gaussian assumption</h2>
<p>This corresponds to linear regression with a Gaussian noise assumption.</p>
<p>The variable of interest <img src="https://latex.codecogs.com/png.latex?%5Cmathbf%7Bx%7D"> is assumed to be a deterministic function <img src="https://latex.codecogs.com/png.latex?f(%5Cmathbf%7Bz%7D;%20%5Ctheta)"> with additional Gaussian noise, so that: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cmathbf%7Bx%7D%20=%20f(%5Cmathbf%7Bz%7D;%20%5Ctheta)%20+%20%5Cepsilon,%0A"> where: <img src="https://latex.codecogs.com/png.latex?%5Cepsilon%20%5Csim%20%5Cmathcal%7BN%7D%5Cleft(%20%5Cepsilon;%200,%20%5CLambda%5E%7B-1%7D%20%5Cright)">. Thus, the reconstruction likelihood can be written as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20p(%5Cmathbf%7Bx%7D%20%5Cvert%20%5Cmathbf%7Bz%7D,%20%5Ctheta,%20%5CLambda)%20=%20%5Cmathcal%7BN%7D(%5Cmathbf%7Bx%7D;%20f(%5Cmathbf%7Bz%7D;%20%5Ctheta),%20%5CLambda%5E%7B-1%7D)%20=%20%5Cprod_%7Bn=1%7D%5E%7BN%7D%20%5Cmathcal%7BN%7D(x_%7Bn%7D;%20f(z_%7Bn%7D;%20%5Ctheta),%20%5CLambda%5E%7B-1%7D).%0A"> Hence, the negative log-likelihood, or the reconstruction loss in the VAE, can be expressed as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20-%5Cln%20p(%5Cmathbf%7Bx%7D%20%5Cvert%20%5Cmathbf%7Bz%7D,%20%5Ctheta,%20%5CLambda)%20=%20-%20%5Cfrac%7BN%7D%7B2%7D%20%5Cln%20%5Cfrac%7B%5CLambda%7D%7B2%20%5Cpi%7D%20+%20%5CLambda%20%5Ctimes%20%5Cfrac%7B1%7D%7B2%7D%20%5Cunderbrace%7B%5Csum_%7Bn=1%7D%5E%7BN%7D%20%5Cleft%5B%20x_%7Bn%7D%20-%20f(z_%7Bn%7D;%20%5Ctheta)%20%5Cright%5D%5E%7B2%7D%7D_%7BN%20%5Ctimes%20%5Ctext%7BMSE%7D%7D.%20%5Ctag%7Bnll-G%7D%0A"></p>
<blockquote class="blockquote">
<p>Note that current practice uses only MSE, which ignores the first term and the scaling factor relating to the noise precision <img src="https://latex.codecogs.com/png.latex?%5CLambda">.</p>
</blockquote>
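<p>To make the difference between (nll-G) and plain MSE concrete, here is a minimal sketch in plain Python. The function name <code>gaussian_nll</code> and the toy numbers are my own assumptions for illustration, and the noise precision is treated as a single scalar shared by all pixels.</p>

```python
import math


def gaussian_nll(x, x_mean, precision):
    """Negative log-likelihood (nll-G) of N scalars under N(x_n; mean_n, 1 / precision)."""
    n = len(x)
    # Sum of squared errors, which equals N * MSE.
    sum_sq_err = sum((xi - mi) ** 2 for xi, mi in zip(x, x_mean))
    # First term of (nll-G): the log normalising constant that plain MSE drops.
    log_norm = -0.5 * n * math.log(precision / (2.0 * math.pi))
    return log_norm + 0.5 * precision * sum_sq_err


x = [0.2, 0.7, 0.9]        # hypothetical pixel intensities
x_mean = [0.25, 0.6, 0.8]  # hypothetical decoder outputs f(z_n; theta)

mse = sum((a - b) ** 2 for a, b in zip(x, x_mean)) / len(x)
nll = gaussian_nll(x, x_mean, precision=100.0)
print(mse, nll)
```

<p>The two printed numbers differ both by an additive term and by a scale factor, and both depend on the precision. This is exactly what gets lost when MSE alone is used as the reconstruction loss.</p>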
<p>Under this modelling approach, the decoder would consist of two networks: one for the mean <img src="https://latex.codecogs.com/png.latex?%5Cbar%7Bx%7D%20=%20f(z;%20%5Ctheta)"> and the other for the noise precision <img src="https://latex.codecogs.com/png.latex?%5CLambda%20=%20g(z;%20%5Cphi)">. Of course, one can treat <img src="https://latex.codecogs.com/png.latex?%5CLambda"> as a hyper-parameter to further simplify the implementation.</p>
<p>The “full” loss function of a VAE is, therefore, given by: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cboxed%7B%0A%20%20%20%20%5Cmathbb%7BE%7D_%7Bq(%5Cmathbf%7Bz%7D)%7D%20%5Cleft%5B%20%5Cfrac%7BN%7D%7B2%7D%20%5Cln(2%5Cpi)%20-%20%5Cfrac%7BN%7D%7B2%7D%20%5Cln%20%5CLambda%20+%20%5Cfrac%7B%5CLambda%7D%7B2%7D%20%5Csum_%7Bn=1%7D%5E%7BN%7D%20%5Cleft%5B%20x_%7Bn%7D%20-%20f(z_%7Bn%7D;%20%5Ctheta)%20%5Cright%5D%5E%7B2%7D%20%5Cright%5D%20+%20%5Cmathrm%7BKL%7D%20%5Cleft%5B%20q(%5Cmathbf%7Bz%7D)%20%5CVert%20p(%5Cmathbf%7Bz%7D)%20%5Cright%5D.%20%5Ctag%7Bvfe-G%7D%0A%20%20%20%20%7D%0A"></p>
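<p>One way to connect (vfe-G) back to the weighting trick, as I read it: if the precision is held fixed, dividing (vfe-G) by ΛN/2 leaves MSE plus the KL term scaled by 2/(ΛN), up to an additive constant. In that reading, tuning β is an implicit choice of noise precision. The sketch below only does this arithmetic. The helper names <code>implied_beta</code> and <code>best_precision</code> and the numbers are hypothetical, and the formula 1/MSE is the precision that minimises (nll-G) when the residuals are held fixed.</p>

```python
def implied_beta(precision, n):
    """KL weight left after dividing (vfe-G) by precision * n / 2."""
    return 2.0 / (precision * n)


def best_precision(mse):
    """Precision minimising (nll-G) when the residuals are held fixed."""
    return 1.0 / mse


n = 784     # assumed number of terms in the sum, e.g. pixels of a 28 x 28 image
mse = 0.01  # assumed mean squared error of the decoder
precision = best_precision(mse)
print(precision, implied_beta(precision, n))
```

<p>Under these assumed numbers the implied weight is far below 1, which is in line with the observation that a small β is needed when MSE alone is used.</p>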
<p>After training, one can pass an image through the encoder <img src="https://latex.codecogs.com/png.latex?h(.;%20%5Cphi)"> and the decoder to get the predicted mean and precision. The reconstructed image can then be obtained as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Chat%7Bx%7D%20%5Csim%20%5Cmathcal%7BN%7D(x;%20f(z;%20%5Ctheta),%20%5CLambda%5E%7B-1%7D),%20%5Ctext%7Bwhere%20%7D%20z%20=%20h(x;%20%5Cphi).%0A"> Although this approach is easy to understand, one drawback is the unbounded support of the Gaussian distribution, which can produce reconstructed pixel intensities outside the desired range <img src="https://latex.codecogs.com/png.latex?%5B0,%201%5D">. Consequently, when visualising, the pixels that fall outside that range are clipped to 0 or 1, potentially making the reconstructed images blurrier.</p>
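<p>The sampling and clipping step can be sketched as follows. This is only an illustration in plain Python: the function name, the decoder means and the precision value are made up, and the variance is taken to be the inverse of the precision, as in the noise model of this section.</p>

```python
import random


def sample_gaussian_reconstruction(x_mean, precision, rng):
    """Draw x_hat ~ N(mean, 1 / precision) per pixel, then clip to [0, 1] for display."""
    std = precision ** -0.5
    samples = [rng.gauss(m, std) for m in x_mean]
    # Values outside [0, 1] cannot be displayed as intensities, so they are clipped.
    return [min(1.0, max(0.0, s)) for s in samples]


rng = random.Random(0)      # fixed seed, for reproducibility only
x_mean = [0.02, 0.5, 0.97]  # hypothetical decoder means f(z; theta)
x_hat = sample_gaussian_reconstruction(x_mean, precision=25.0, rng=rng)
print(x_hat)
```

<p>Means close to 0 or 1 are the ones most likely to be pushed out of range by the noise and then clipped.</p>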
</section>
<section id="reconstruction-likelihood-with-continuous-bernoulli-assumption" class="level2" data-number="3">
<h2 data-number="3" class="anchored" data-anchor-id="reconstruction-likelihood-with-continuous-bernoulli-assumption"><span class="header-section-number">3</span> Reconstruction likelihood with continuous Bernoulli assumption</h2>
<p>This corresponds to linear regression in <img src="https://latex.codecogs.com/png.latex?%5B0,%201%5D"> (not <img src="https://latex.codecogs.com/png.latex?%5C%7B0,%201%5C%7D"> as in logistic regression), hence the words “continuous Bernoulli”.</p>
<p>This modelling approach is not as intuitive as the one with Gaussian assumption, but please bear with me for a moment.</p>
<p>The likelihood of interest, <img src="https://latex.codecogs.com/png.latex?p(%5Cmathbf%7Bx%7D%20%5Cvert%20%5Cmathbf%7Bz%7D)">, is assumed to be a <a href="https://papers.nips.cc/paper/2019/hash/f82798ec8909d23e55679ee26bb26437-Abstract.html">continuous Bernoulli distribution</a>: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20p(%5Cmathbf%7Bx%7D%20%5Cvert%20%5Cmathbf%7Bz%7D)%20=%20%5Cmathcal%7BCB%7D(%5Cmathbf%7Bx%7D;%20f(%5Cmathbf%7Bz%7D;%20%5Ctheta))%20=%20%5Cprod_%7Bn=1%7D%5E%7BN%7D%20%5Cunderbrace%7BC%20%5Cleft(%20f(z_%7Bn%7D;%20%5Ctheta)%20%5Cright)%7D_%7B%5Ctext%7Bnormalising%20const.%7D%7D%20%20%5Cunderbrace%7B%5Cleft%5B%20f(z_%7Bn%7D;%20%5Ctheta)%20%5Cright%5D%5E%7Bx_%7Bn%7D%7D%20%5Cleft%5B%201%20-%20f(z_%7Bn%7D;%20%5Ctheta)%20%5Cright%5D%5E%7B1%20-%20x_%7Bn%7D%7D%7D_%7B%5Ctext%7BBernoulli%20pdf%7D%7D,%0A"> where <img src="https://latex.codecogs.com/png.latex?f(z_%7Bn%7D;%20%5Ctheta)%20%5Cin%20%5B0,%201%5D,%20%5Cforall%20n%20%5Cin%20%5C%7B1,%20%5Cldots,%20N%5C%7D">.</p>
<p>Note that:</p>
<ul>
<li>the continuous Bernoulli distribution is used because the VAE tries to regress the pixel intensity <img src="https://latex.codecogs.com/png.latex?x_%7Bn%7D">, which falls in <img src="https://latex.codecogs.com/png.latex?%5B0,%201%5D">, not <img src="https://latex.codecogs.com/png.latex?%5C%7B0,%201%5C%7D"> as in classification,</li>
<li>the pdf of a continuous Bernoulli distribution differs from that of a Bernoulli distribution by the normalising constant term,</li>
<li>the output of the decoder is no longer the mean of the reconstructed pixel intensity, as it was in the Gaussian case,</li>
<li>due to the assumption of the continuous Bernoulli distribution, the last layer of the decoder must be activated by a sigmoid function to ensure that the output falls in <img src="https://latex.codecogs.com/png.latex?%5B0,%201%5D">.</li>
</ul>
<p>The negative log-likelihood, or reconstruction loss, can be easily derived as: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20-%20%5Cln%20p(%5Cmathbf%7Bx%7D%20%5Cvert%20%5Cmathbf%7Bz%7D)%20=%20%5Csum_%7Bn=1%7D%5E%7BN%7D%20%5Cunderbrace%7B%20-%20%5Cleft%5B%20x_%7Bn%7D%20%5Cln%20f(z_%7Bn%7D;%20%5Ctheta)%20+%20(1%20-%20x_%7Bn%7D)%20%5Cln%20%5Cleft%5B1%20-%20f(z_%7Bn%7D;%20%5Ctheta)%20%5Cright%5D%20%5Cright%5D%7D_%7B%5Ctext%7Bbinary%20cross-entropy%7D%7D%20-%20%5Cunderbrace%7B%5Cln%20C%20%5Cleft(%20f(z_%7Bn%7D;%20%5Ctheta)%20%5Cright)%7D_%7B%5Ctext%7Blog%20normalising%20const.%7D%7D.%20%5Ctag%7Bnll-CB%7D%0A"></p>
<blockquote class="blockquote">
<p>Current practice uses the binary cross-entropy loss only, which corresponds to a Bernoulli distribution. To me, that practice is not correct, since the learning then infers the parameter of the Bernoulli distribution, which is the probability that the outcome is 1. In that case, the pixel intensity is in <img src="https://latex.codecogs.com/png.latex?%5C%7B0,%201%5C%7D">, not <img src="https://latex.codecogs.com/png.latex?%5B0,%201%5D">. This explains why a VAE using the binary cross-entropy loss often works well for grey-scale, but not colour, images.</p>
</blockquote>
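<p>As a rough sketch of (nll-CB), the code below evaluates the log normalising constant and the resulting loss in plain Python. It uses the closed form given in the continuous Bernoulli paper, C(λ) = 2 artanh(1 − 2λ) / (1 − 2λ) for λ ≠ 0.5 and C(0.5) = 2. The function names, the tolerance <code>tol</code> and the toy inputs are my own assumptions.</p>

```python
import math


def log_cb_norm_const(lam, tol=1e-6):
    """ln C(lam) for the continuous Bernoulli; C(lam) tends to 2 as lam tends to 0.5."""
    t = 1.0 - 2.0 * lam
    if math.isclose(t, 0.0, abs_tol=tol):
        return math.log(2.0)
    return math.log(2.0 * math.atanh(t) / t)


def cb_nll(x, params):
    """(nll-CB): binary cross-entropy minus the log normalising constant, summed over pixels."""
    total = 0.0
    for xn, lam in zip(x, params):
        bce = -(xn * math.log(lam) + (1.0 - xn) * math.log(1.0 - lam))
        total += bce - log_cb_norm_const(lam)
    return total


x = [0.2, 0.7, 0.9]        # hypothetical pixel intensities in [0, 1]
params = [0.3, 0.6, 0.95]  # hypothetical sigmoid outputs f(z_n; theta)
print(cb_nll(x, params))
```

<p>Since C(λ) is never smaller than 2, the correction term is always at least ln 2 per pixel, and it grows as the decoder output approaches 0 or 1. Dropping it therefore changes the loss by an amount that depends on the decoder output, not by a constant.</p>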
<p>Substituting (nll-CB) into (vfe) gives the “full” objective function for the VAE: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Cboxed%7B%0A%20%20%20%20%20%20%20%20%5Cbegin%7Baligned%7D%0A%20%20%20%20%20%20%20%20&amp;%20-%20%5Cmathbb%7BE%7D_%7Bq(%5Cmathbf%7Bz%7D)%7D%20%5Cleft%5B%20%5Csum_%7Bn=1%7D%5E%7BN%7D%20x_%7Bn%7D%20%5Cln%20f(z_%7Bn%7D;%20%5Ctheta)%20+%20(1%20-%20x_%7Bn%7D)%20%5Cln%20%5Cleft%5B1%20-%20f(z_%7Bn%7D;%20%5Ctheta)%20%5Cright%5D%20%5Cright.%20%5C%5C%0A%20%20%20%20%20%20%20%20&amp;%20%5Cquad%20%5Cleft.%20+%20%5Cln%20C%20%5Cleft(%20f(z_%7Bn%7D;%20%5Ctheta)%20%5Cright)%20%5Cright%5D%20+%20%5Cmathrm%7BKL%7D%20%5Cleft%5B%20q(%5Cmathbf%7Bz%7D)%20%5CVert%20p(%5Cmathbf%7Bz%7D)%20%5Cright%5D.%0A%20%20%20%20%20%20%20%20%5Cend%7Baligned%7D%0A%20%20%20%20%20%20%20%20%5Ctag%7Bvfe-CB%7D%0A%20%20%20%20%7D%0A"></p>
<p>Note that after training, directly plotting <img src="https://latex.codecogs.com/png.latex?f(z;%20%5Ctheta)"> as the pixel intensity might result in an incorrect reconstructed image, since the mean of the continuous Bernoulli distribution is not equal to its parameter. To reconstruct an image <img src="https://latex.codecogs.com/png.latex?x">, one needs to pass that image through the encoder and decoder, and then sample: <img src="https://latex.codecogs.com/png.latex?%0A%20%20%20%20%5Chat%7Bx%7D%20%5Csim%20%5Cmathcal%7BCB%7D%5Cleft(x;%20f(z;%20%5Ctheta)%20%5Cright),%0A"> and plot <img src="https://latex.codecogs.com/png.latex?%5Chat%7Bx%7D"> to visualise the reconstructed image.</p>
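<p>To illustrate the gap between the parameter and the mean, and one way to draw the sample, here is a small sketch in plain Python. The mean uses the closed form from the continuous Bernoulli paper, and the sampler inverts the CDF, which I derived from the density above, so treat it as an assumption to double check. The names <code>cb_mean</code> and <code>cb_sample</code> are hypothetical.</p>

```python
import math
import random


def cb_mean(lam, tol=1e-6):
    """Mean of the continuous Bernoulli with parameter lam in (0, 1)."""
    t = 1.0 - 2.0 * lam
    if math.isclose(t, 0.0, abs_tol=tol):
        return 0.5
    return lam / (2.0 * lam - 1.0) + 1.0 / (2.0 * math.atanh(t))


def cb_sample(lam, rng, tol=1e-6):
    """Draw one sample by inverting the continuous Bernoulli CDF."""
    u = rng.random()
    if math.isclose(lam, 0.5, abs_tol=tol):
        return u  # the distribution is uniform on [0, 1] when lam = 0.5
    num = math.log(u * (2.0 * lam - 1.0) + 1.0 - lam) - math.log(1.0 - lam)
    return num / (math.log(lam) - math.log(1.0 - lam))


rng = random.Random(0)  # fixed seed, for reproducibility only
lam = 0.9               # hypothetical decoder output for one pixel
print(lam, cb_mean(lam), cb_sample(lam, rng))
```

<p>For a decoder output of 0.9 the mean is noticeably closer to 0.5 than the parameter itself, which is why plotting the raw decoder output can give a misleading picture.</p>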
</section>
<section id="conclusion" class="level2" data-number="4">
<h2 data-number="4" class="anchored" data-anchor-id="conclusion"><span class="header-section-number">4</span> Conclusion</h2>
<p>The VAE is often considered a basic generative model. However, many machine learning practitioners learn the “type” of reconstruction loss by memorisation, which leads to the weighting trick in the implementation. Understanding the reconstruction loss as the negative log-likelihood in linear regression allows one to obtain the “full” objective function without applying any weighting tricks. Hopefully, this post will save some time for those who are starting to practise machine learning.</p>
</section>
<section id="references" class="level2" data-number="5">
<h2 data-number="5" class="anchored" data-anchor-id="references"><span class="header-section-number">5</span> References</h2>
<ol type="1">
<li>Higgins, I., Matthey, L., Pal, A., Burgess, C., Glorot, X., Botvinick, M., Mohamed, S. and Lerchner, A., 2016. <a href="https://openreview.net/forum?id=Sy2fzU9gl">β-VAE: Learning basic visual concepts with a constrained variational framework</a>. In International Conference on Learning Representations.</li>
<li>Loaiza-Ganem, G. and Cunningham, J.P., 2019. <a href="https://papers.nips.cc/paper/2019/hash/f82798ec8909d23e55679ee26bb26437-Abstract.html">The continuous Bernoulli: fixing a pervasive error in variational autoencoders</a>. In Advances in Neural Information Processing Systems (pp.&nbsp;13287-13297).</li>
</ol>


<!-- -->

</section>

<div id="quarto-appendix" class="default"><section class="quarto-appendix-contents" id="quarto-reuse"><h2 class="anchored quarto-appendix-heading">Reuse</h2><div class="quarto-appendix-contents"><div><a rel="license" href="https://creativecommons.org/licenses/by/4.0/">CC BY 4.0</a></div></div></section><section class="quarto-appendix-contents" id="quarto-citation"><h2 class="anchored quarto-appendix-heading">Citation</h2><div><div class="quarto-appendix-secondary-label">BibTeX citation:</div><pre class="sourceCode code-with-copy quarto-appendix-bibtex"><code class="sourceCode bibtex">@online{nguyen2020,
  author = {Nguyen, Cuong},
  title = {VAE: Normalising Constant Matters},
  date = {2020-11-24},
  url = {https://cnguyen10.github.io/posts/vae-normalising-constant-matters/},
  langid = {en}
}
</code></pre><div class="quarto-appendix-secondary-label">For attribution, please cite this work as:</div><div id="ref-nguyen2020" class="csl-entry quarto-appendix-citeas">
Nguyen, Cuong. 2020. <span>“VAE: Normalising Constant Matters.”</span>
November 24. <a href="https://cnguyen10.github.io/posts/vae-normalising-constant-matters/">https://cnguyen10.github.io/posts/vae-normalising-constant-matters/</a>.
</div></div></section></div> ]]></description>
  <category>Deep Learning</category>
  <category>Generative Models</category>
  <category>Variational Inference</category>
  <guid>https://cnguyen10.github.io/posts/vae-normalising-constant-matters/</guid>
  <pubDate>Tue, 24 Nov 2020 00:00:00 GMT</pubDate>
</item>
</channel>
</rss>
